Unverified Commit d01a62ef authored by Jinhua Zhu, committed by GitHub

valid with ema (#16)



* valid with ema

* save ema model

* update ema

* rm model_fp32

* rm ordered names

* load ema

* ema.py

* Update unicore/trainer.py
Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
parent d1dc264d
from copy import deepcopy
from itertools import chain

import torch

from unicore.optim.fp16_optimizer import pad_numel


class ExponentialMovingAverageModel:
    """Maintains an exponential moving average (EMA) of model parameters.

    The EMA weights live in a float copy of the model whose parameters are
    aliased into a single flattened fp32 buffer, so each update is one
    vectorized operation: param = decay * param + (1 - decay) * new_param.
    """

    def __init__(self, model, decay, init_param=None):
        self.model_ema = deepcopy(model).float()
        self.decay = decay
        self.param = self.flatten_parameters(model, init_param)

    def flatten_parameters(self, model, init_param):
        # group parameter names by dtype, preserving first-seen dtype order
        dtype_grouped_names = dict()
        ordered_dtype = []
        for n, p in model.named_parameters():
            if p.dtype not in dtype_grouped_names:
                dtype_grouped_names[p.dtype] = []
                ordered_dtype.append(p.dtype)
            dtype_grouped_names[p.dtype].append(n)
        ordered_names = list(chain(*(dtype_grouped_names[n] for n in ordered_dtype)))
        name2param = dict()
        for n, p in self.model_ema.named_parameters():
            name2param[n] = p
        cur_params = [name2param[n] for n in ordered_names]
        # allocate one padded fp32 buffer and alias each EMA parameter into it
        total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
        flatten_param = cur_params[0].new(0).float().new_zeros(total_param_size)
        offset = 0
        for p in cur_params:
            numel = p.data.numel()
            flatten_param[offset : offset + numel].copy_(p.data.view(-1))
            p.data = flatten_param.data[offset : offset + numel].view(*p.shape)
            offset += pad_numel(numel)
        flatten_param = torch.nn.Parameter(flatten_param)
        if init_param is not None:
            # layout must match the optimizer's flattened fp32 params
            assert torch.allclose(init_param, flatten_param), "ema init error!"
        torch.cuda.empty_cache()
        return flatten_param

    def update(self, new_param):
        # in place: param = decay * param + (1 - decay) * new_param
        with torch.no_grad():
            diff = self.param - new_param
            diff *= 1 - self.decay
            self.param -= diff

    def load_state_dict(self, state_dict):
        self.model_ema.load_state_dict(state_dict["params"])
        self.decay = state_dict["decay"] if "decay" in state_dict else self.decay

    def state_dict(self):
        return {
            "params": self.model_ema.state_dict(),
            "decay": self.decay,
        }
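For orientation, a minimal usage sketch of the new class. The toy Linear model and the hand-rolled new_flat tensor are illustrative stand-ins for the real model and for FP16Optimizer.fp32_params; the module path unicore.ema matches the import added in trainer.py below.

import torch
from unicore.ema import ExponentialMovingAverageModel

model = torch.nn.Linear(4, 4)  # hypothetical toy model
ema = ExponentialMovingAverageModel(model, decay=0.999)

# stand-in for the optimizer's flattened fp32 params after one step
new_flat = ema.param.detach().clone() + 0.01
ema.update(new_flat)  # ema.param moves toward new_flat by (1 - decay)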
@@ -35,7 +35,7 @@ class _FP16OptimizerMixin(object):
     def build_fp32_params(cls, args, params):
         # create FP32 copy of parameters and grads
         total_param_size = sum([p.data.numel() for p in params])
-        fp32_params = params[0].new(0).float().new(total_param_size)
+        fp32_params = params[0].new(0).float().new_zeros(total_param_size)
         offset = 0
         for p in params:
             numel = p.data.numel()
@@ -48,9 +48,11 @@ class _FP16OptimizerMixin(object):
     @classmethod
     def flatten_fp16_parameters(cls, args, params):
         dtype_grouped_params = {}
+        ordered_dtype = []  # for sort dtype
         for p in params:
             if p.dtype not in dtype_grouped_params:
                 dtype_grouped_params[p.dtype] = []
+                ordered_dtype.append(p.dtype)
             dtype_grouped_params[p.dtype].append(p)
         flatten_params = {}
@@ -58,7 +60,7 @@ class _FP16OptimizerMixin(object):
             cur_params = dtype_grouped_params[dtype]
             total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
             flatten_params[dtype] = (
-                cur_params[0].new(0).type(dtype).new(total_param_size)
+                cur_params[0].new(0).type(dtype).new_zeros(total_param_size)
             )
             offset = 0
             for p in cur_params:
@@ -80,7 +82,7 @@ class _FP16OptimizerMixin(object):
                 )
             offset += pad_numel(numel)
         torch.cuda.empty_cache()
-        return list(flatten_params.values())
+        return [flatten_params[dtype] for dtype in ordered_dtype]

     def state_dict(self):
         """Return the optimizer's state dict."""
@@ -91,7 +93,6 @@ class _FP16OptimizerMixin(object):
     def load_state_dict(self, state_dict, optimizer_overrides=None):
         """Load an optimizer state dict.
-
         In general we should prefer the configuration of the existing optimizer
         instance (e.g., learning rate) over that found in the state_dict. This
         allows us to resume training from a checkpoint using a new set of
@@ -103,7 +104,6 @@ class _FP16OptimizerMixin(object):
     def backward(self, loss):
         """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
-
         Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this
         function additionally dynamically scales the loss to avoid gradient
         underflow.
@@ -194,6 +194,7 @@ def get_parser(desc, default_task='test'):
                        "main method can return a value (useful for sweeps)")
    parser.add_argument('--profile', action='store_true', help="enable autograd profiler emit_nvtx")
    parser.add_argument('--ema-decay', default=-1.0, type=float, help="enable moving average for model weights")
+   parser.add_argument("--validate-with-ema", default=False, action="store_true")
    from unicore.registry import REGISTRIES
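The two flags are meant to be used together: --ema-decay turns EMA tracking on (the -1.0 default leaves it disabled; values close to 1, e.g. 0.999, are typical), while --validate-with-ema switches validation over to the EMA weights. The Trainer change below enforces that the second flag is only used when the first is set.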
@@ -15,78 +15,18 @@ import sys
 import time
 from itertools import chain
 from typing import Any, Dict, List

 import torch

 from unicore import checkpoint_utils, models, optim, utils
 from unicore.distributed import utils as distributed_utils
 from unicore.logging import meters, metrics
 from unicore.nan_detector import NanDetector
 from unicore.optim import lr_scheduler
-from unicore.utils import tensor_tree_map
+from unicore.ema import ExponentialMovingAverageModel

 logger = logging.getLogger(__name__)

-
-class ExponentialMovingAverage:
-    """
-    Maintains moving averages of parameters with exponential decay
-
-    At each step, the stored copy `copy` of each parameter `param` is
-    updated as follows:
-
-        `copy = decay * copy + (1 - decay) * param`
-
-    where `decay` is an attribute of the ExponentialMovingAverage object.
-    """
-
-    def __init__(self, model: torch.nn.Module, decay: float):
-        """
-        Args:
-            model:
-                A torch.nn.Module whose parameters are to be tracked
-            decay:
-                A value (usually close to 1.) by which updates are
-                weighted as part of the above formula
-        """
-        super(ExponentialMovingAverage, self).__init__()
-
-        with torch.no_grad():
-            clone_param = lambda t: t.clone().detach().float()
-            self.params = tensor_tree_map(clone_param, model.state_dict())
-
-        self.decay = decay
-
-    def _update_state_dict_(self, update, state_dict):
-        with torch.no_grad():
-            for k, v in update.items():
-                if state_dict[k].device != v.device:
-                    state_dict[k] = state_dict[k].to(v.device)
-                stored = state_dict[k]
-                if not isinstance(v, torch.Tensor):
-                    self._update_state_dict_(v, stored)
-                else:
-                    diff = stored - v.float()
-                    diff *= 1 - self.decay
-                    stored -= diff
-
-    def update(self, model: torch.nn.Module) -> None:
-        """
-        Updates the stored parameters using the state dict of the provided
-        module. The module should have the same structure as that used to
-        initialize the ExponentialMovingAverage object.
-        """
-        self._update_state_dict_(model.state_dict(), self.params)
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        self.params = state_dict["params"]
-        self.decay = state_dict["decay"] if "decay" in state_dict else self.decay
-
-    def state_dict(self) -> dict:
-        return {
-            "params": self.params,
-            "decay": self.decay,
-        }
-
-
 class Trainer(object):
     """Main class for data parallel training.
@@ -167,8 +107,21 @@ class Trainer(object):
             self.cuda_env_arr = None

         # add ema
-        if args.ema_decay > 0 and self.data_parallel_rank == 0:
-            self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
+        if args.validate_with_ema:
+            assert args.ema_decay > 0, "valid with ema must with ema_decay > 0"
+        if args.ema_decay > 0 and (
+            self.data_parallel_rank == 0 or args.validate_with_ema
+        ):
+            assert isinstance(
+                self.optimizer, optim.FP16Optimizer
+            ), "ema must with fp16 optimizer"
+            self.ema = ExponentialMovingAverageModel(
+                model,
+                args.ema_decay,
+                self._optimizer.fp32_params,
+            )
         else:
             self.ema = None
         metrics.log_start_time("wall", priority=790, round=2)
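Note the placement change: previously the EMA lived only on data-parallel rank 0 (sufficient for checkpointing), but validating with EMA runs on every rank, so each rank now needs its own copy. Passing self._optimizer.fp32_params lets the constructor assert that its flattened layout matches the FP16 optimizer's, which is what makes the fused update further down valid.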
@@ -434,7 +387,9 @@ class Trainer(object):
                 logger.info(
                     f"Cannot find EMA state in checkpoint, load model weight to ema directly"
                 )
-                self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
+                self.ema = ExponentialMovingAverageModel(
+                    self._model, decay=self.ema.decay
+                )

         if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
...@@ -730,7 +685,7 @@ class Trainer(object): ...@@ -730,7 +685,7 @@ class Trainer(object):
) )
if self.ema is not None: if self.ema is not None:
with torch.autograd.profiler.record_function("ema"): with torch.autograd.profiler.record_function("ema"):
self.ema.update(self.model) self.ema.update(self.optimizer.fp32_params)
except FloatingPointError: except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print # re-run the forward and backward pass with hooks attached to print
......
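Switching from self.ema.update(self.model) to self.ema.update(self.optimizer.fp32_params) replaces a Python-level walk over the model's state dict with a single vectorized update on the flattened fp32 buffer. A toy illustration of the arithmetic (tensor names are illustrative):

import torch

decay = 0.999
flat_ema = torch.zeros(8)  # EMA's flattened fp32 parameter
flat_new = torch.ones(8)   # optimizer's flattened fp32 parameter

# one fused in-place op, equivalent to
# flat_ema = decay * flat_ema + (1 - decay) * flat_new
flat_ema -= (1 - decay) * (flat_ema - flat_new)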
@@ -420,3 +420,17 @@ def set_jit_fusion_options():
     torch._C._jit_set_profiling_executor(False)
     torch._C._jit_override_can_fuse_on_cpu(True)
     torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+@contextlib.contextmanager
+def validate_with_ema(trainer, ema=False):
+    if not ema:
+        yield
+        return
+
+    _wrapped_model = trainer._wrapped_model
+    trainer._wrapped_model = trainer.ema.model_ema
+    try:
+        yield
+    finally:
+        trainer._wrapped_model = _wrapped_model
\ No newline at end of file
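The context manager swaps the trainer's wrapped model for the EMA copy only for the duration of validation, and the try/finally guarantees the real model is restored even if validation raises. Because ExponentialMovingAverageModel keeps model_ema as a full float module whose parameters alias the flat EMA buffer, no weight copying is needed at swap time.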
@@ -302,7 +302,8 @@ def validate_and_save(
     # Validate
     valid_losses = [None]
     if do_validate:
-        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+        with utils.validate_with_ema(trainer, ema=args.validate_with_ema):
+            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
     should_stop |= should_stop_early(args, valid_losses[0])