from contextlib import contextmanager
import copy
import logging
import os
import sys

import numpy as np
import torch

try:
    from termcolor import colored
except ImportError:  # colouring is cosmetic; don't make it a hard dependency
    def colored(text, *args, **kwargs):
        """Fallback used when termcolor is unavailable: return *text* unchanged."""
        return text


class MagnitudeRecover():
    """Nudge parameters back toward their recorded magnitude during training.

    At construction, the mean of squares of every parameter of *model* is
    recorded. On each :meth:`regularize` call, every still-tracked parameter
    whose current mean-square is at or below its recorded value gets
    ``-reg * p`` added to its gradient. For an optimizer that steps along
    ``-grad`` (e.g. plain SGD), that step scales the parameter up. Once a
    parameter's mean-square exceeds its record, it is dropped from tracking
    for good.

    Args:
        model: module whose ``named_parameters()`` are recorded.
        reg: strength of the gradient nudge.
    """

    def __init__(self, model, reg=1e-3):
        self.rec = {}   # parameter name -> recorded mean-square (0-d tensor)
        self.reg = reg
        self.cnt = 0    # number of regularize() calls so far (drives printing)
        with torch.no_grad():
            for name, p in model.named_parameters():
                self.rec[name] = p.pow(2).mean()

    def regularize(self, model):
        """Add the magnitude-recovery term to the gradients of *model*.

        Presumably meant to be called after ``backward()`` and before
        ``optimizer.step()`` (TODO confirm with callers). Parameters with no
        gradient (``p.grad is None``) are skipped. Every 1000th call prints
        the current and target mean-square of each parameter it nudges.
        """
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name not in self.rec:
                    continue
                target_norm = self.rec[name]
                if p.data.pow(2).mean() > target_norm:
                    # Magnitude has been recovered; stop tracking this parameter.
                    self.rec.pop(name)
                    continue
                if p.grad is None:
                    # Frozen or unused parameter: nothing to adjust.
                    continue
                p.grad.add_(p.data, alpha=-self.reg)
                if self.cnt % 1000 == 0:
                    print(name, p.pow(2).mean(), target_norm)
        self.cnt += 1


def flatten_dict(dic):
    """Flatten a nested dict into a single level.

    Keys of nested dicts are joined with ``'/'`` and converted to strings
    (e.g. ``{'a': {'b': 1}}`` -> ``{'a/b': 1}``). Top-level keys whose value
    is not a dict are kept unchanged. Nested empty dicts contribute no entries.
    """
    flattened = dict()

    def _flatten(prefix, d):
        for k, v in d.items():
            key = k if prefix is None else f"{prefix}/{k}"
            if isinstance(v, dict):
                _flatten(key, v)
            else:
                flattened[key] = v

    _flatten(None, dic)
    return flattened


class _ColorfulFormatter(logging.Formatter):
    """Formatter that prefixes WARNING / ERROR / CRITICAL records with a coloured tag.

    Records at other levels are returned exactly as the base formatter
    renders them.
    """

    def formatMessage(self, record):
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "yellow", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


def get_logger(name='train', output=None, color=True):
    """Return the logger *name*, configured to log DEBUG and above to stdout and optionally to a file.

    Args:
        name: logger name passed to ``logging.getLogger``.
        output: ``None`` for stdout only. A path ending in ``.txt`` or ``.log``
            is used as the log file, and its parent directory is created if
            needed. Any other path is treated as a directory, created if
            needed, and ``log.txt`` is written inside it. ``os.PathLike`` is
            accepted.
        color: use the coloured stdout formatter instead of the plain one.
            The file handler always uses the plain formatter.

    NOTE: ``logging.getLogger`` returns the same object for the same name, and
    every call adds new handlers. Calling this twice for one name therefore
    duplicates output.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False  # avoid double printing through the root logger

    # STDOUT
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
        )
    else:
        formatter = plain_formatter
    stdout_handler.setFormatter(formatter)
    logger.addHandler(stdout_handler)

    # FILE
    if output is not None:
        output = os.fspath(output)
        if output.endswith(('.txt', '.log')):
            parent = os.path.dirname(output)
            if parent:  # bare filename like "log.txt": os.makedirs("") would raise
                os.makedirs(parent, exist_ok=True)
            filename = output
        else:
            os.makedirs(output, exist_ok=True)
            filename = os.path.join(output, "log.txt")
        file_handler = logging.FileHandler(filename)
        file_handler.setFormatter(plain_formatter)
        file_handler.setLevel(logging.DEBUG)
        logger.addHandler(file_handler)
    return logger