Unverified Commit d01a62ef authored by Jinhua Zhu, committed by GitHub

valid with ema (#16)



* valid with ema

* save ema model

* update ema

* rm model_fp32

* rm ordered names

* load ema

* ema.py

* Update unicore/trainer.py
Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
parent d1dc264d
unicore/ema.py (new file)

from copy import deepcopy
from itertools import chain

from unicore.optim.fp16_optimizer import pad_numel

import torch


class ExponentialMovingAverageModel:
    def __init__(self, model, decay, init_param=None):
        self.model_ema = deepcopy(model).float()
        self.decay = decay
        self.param = self.flatten_parameters(model, init_param)

    def flatten_parameters(self, model, init_param):
        # Group parameter names by dtype, remembering the order in which each
        # dtype is first seen, so the layout matches the optimizer's flat fp32 copy.
        dtype_grouped_names = dict()
        ordered_dtype = []
        for n, p in model.named_parameters():
            if p.dtype not in dtype_grouped_names:
                dtype_grouped_names[p.dtype] = []
                ordered_dtype.append(p.dtype)
            dtype_grouped_names[p.dtype].append(n)
        ordered_names = list(chain(*(dtype_grouped_names[n] for n in ordered_dtype)))
        name2param = dict()
        for n, p in self.model_ema.named_parameters():
            name2param[n] = p
        cur_params = [name2param[n] for n in ordered_names]
        # Allocate one flat fp32 buffer; each parameter's size is padded with
        # pad_numel so offsets line up with the FP16 optimizer's flat buffer.
        total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
        flatten_param = cur_params[0].new(0).float().new_zeros(total_param_size)
        offset = 0
        for p in cur_params:
            numel = p.data.numel()
            flatten_param[offset : offset + numel].copy_(p.data.view(-1))
            # Re-point the EMA model's parameter at a view into the flat buffer,
            # so updating the buffer updates the EMA model in place.
            p.data = flatten_param.data[offset : offset + numel].view(*p.shape)
            offset += pad_numel(numel)
        flatten_param = torch.nn.Parameter(flatten_param)
        if init_param is not None:
            assert torch.allclose(init_param, flatten_param), "ema init error!"
        torch.cuda.empty_cache()
        return flatten_param

    def update(self, new_param):
        # In-place EMA step: param = decay * param + (1 - decay) * new_param.
        with torch.no_grad():
            diff = self.param - new_param
            diff *= 1 - self.decay
            self.param -= diff

    def load_state_dict(self, state_dict):
        self.model_ema.load_state_dict(state_dict["params"])
        self.decay = state_dict["decay"] if "decay" in state_dict else self.decay

    def state_dict(self):
        return {
            "params": self.model_ema.state_dict(),
            "decay": self.decay,
        }
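For orientation, a minimal usage sketch of the class above (the toy model, decay value, and the stand-in for the optimizer's flattened fp32 tensor are all assumptions; the trainer below passes optimizer.fp32_params instead). It assumes unicore is installed so the pad_numel import at the top of this file resolves:

import torch

# Toy model; any nn.Module works, since parameters are flattened per dtype.
model = torch.nn.Linear(4, 2)
ema = ExponentialMovingAverageModel(model, decay=0.999)

# Stand-in for optimizer.fp32_params: same layout as the EMA's flat buffer.
new_flat = ema.param.detach().clone() + 1.0
ema.update(new_flat)

# One step moves the EMA a (1 - decay) fraction toward the new parameters.
assert torch.allclose(ema.param, new_flat - 0.999)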
unicore/optim/fp16_optimizer.py

@@ -35,7 +35,7 @@ class _FP16OptimizerMixin(object):
def build_fp32_params(cls, args, params):
# create FP32 copy of parameters and grads
total_param_size = sum([p.data.numel() for p in params])
-        fp32_params = params[0].new(0).float().new(total_param_size)
+        fp32_params = params[0].new(0).float().new_zeros(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
@@ -48,9 +48,11 @@ class _FP16OptimizerMixin(object):
@classmethod
def flatten_fp16_parameters(cls, args, params):
        dtype_grouped_params = {}
+        ordered_dtype = []  # remember the order in which each dtype is first seen
        for p in params:
            if p.dtype not in dtype_grouped_params:
                dtype_grouped_params[p.dtype] = []
+                ordered_dtype.append(p.dtype)
            dtype_grouped_params[p.dtype].append(p)
flatten_params = {}
@@ -58,7 +60,7 @@ class _FP16OptimizerMixin(object):
cur_params = dtype_grouped_params[dtype]
total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
flatten_params[dtype] = (
-            cur_params[0].new(0).type(dtype).new(total_param_size)
+            cur_params[0].new(0).type(dtype).new_zeros(total_param_size)
)
offset = 0
for p in cur_params:
@@ -80,7 +82,7 @@ class _FP16OptimizerMixin(object):
)
offset += pad_numel(numel)
torch.cuda.empty_cache()
-        return list(flatten_params.values())
+        return [flatten_params[dtype] for dtype in ordered_dtype]
def state_dict(self):
"""Return the optimizer's state dict."""
@@ -91,7 +93,6 @@ class _FP16OptimizerMixin(object):
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
@@ -103,7 +104,6 @@ class _FP16OptimizerMixin(object):
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
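The ordered_dtype bookkeeping in flatten_fp16_parameters above exists so the returned buffers come back in a deterministic, first-seen dtype order, which is the layout the EMA buffer in unicore/ema.py reproduces. A self-contained sketch of dtype-ordered, padded flattening (the pad_numel stand-in padding to a multiple of 32 is an assumption about the real helper in this file):

import torch

def pad_numel(numel, align=32):  # assumed behavior of unicore's real helper
    return (numel + align - 1) // align * align

params = [torch.ones(5, dtype=torch.half), torch.ones(3), torch.ones(7, dtype=torch.half)]

# Group by dtype, remembering first-seen order (what ordered_dtype records).
grouped, order = {}, []
for p in params:
    if p.dtype not in grouped:
        grouped[p.dtype] = []
        order.append(p.dtype)
    grouped[p.dtype].append(p)

# One padded flat buffer per dtype, returned in a deterministic order.
flat = [
    torch.zeros(sum(pad_numel(p.numel()) for p in grouped[dt]), dtype=dt)
    for dt in order
]
assert [t.dtype for t in flat] == [torch.half, torch.float32]
assert flat[0].numel() == 64  # 5 -> 32 and 7 -> 32 after padding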
unicore/options.py

@@ -194,6 +194,7 @@ def get_parser(desc, default_task='test'):
"main method can return a value (useful for sweeps)")
parser.add_argument('--profile', action='store_true', help="enable autograd profiler emit_nvtx")
parser.add_argument('--ema-decay', default=-1.0, type=float, help="enable moving average for model weights")
parser.add_argument("--validate-with-ema", default=False, action="store_true")
from unicore.registry import REGISTRIES
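A small sketch of how the new flag composes with the existing --ema-decay (plain argparse; this simplified parser is an assumption, not unicore's actual get_parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ema-decay', default=-1.0, type=float)
parser.add_argument('--validate-with-ema', default=False, action='store_true')

args = parser.parse_args(['--ema-decay', '0.999', '--validate-with-ema'])
assert args.ema_decay > 0       # EMA enabled
assert args.validate_with_ema   # validation will run on the EMA weights

# The -1.0 default keeps EMA off unless explicitly requested.
assert parser.parse_args([]).ema_decay <= 0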
unicore/trainer.py

@@ -15,78 +15,18 @@ import sys
import time
from itertools import chain
from typing import Any, Dict, List
import torch
from unicore import checkpoint_utils, models, optim, utils
from unicore.distributed import utils as distributed_utils
from unicore.logging import meters, metrics
from unicore.nan_detector import NanDetector
from unicore.optim import lr_scheduler
from unicore.utils import tensor_tree_map
+from unicore.ema import ExponentialMovingAverageModel
logger = logging.getLogger(__name__)
-class ExponentialMovingAverage:
-    """
-    Maintains moving averages of parameters with exponential decay
-
-    At each step, the stored copy `copy` of each parameter `param` is
-    updated as follows:
-
-        `copy = decay * copy + (1 - decay) * param`
-
-    where `decay` is an attribute of the ExponentialMovingAverage object.
-    """
-
-    def __init__(self, model: torch.nn.Module, decay: float):
-        """
-        Args:
-            model:
-                A torch.nn.Module whose parameters are to be tracked
-            decay:
-                A value (usually close to 1.) by which updates are
-                weighted as part of the above formula
-        """
-        super(ExponentialMovingAverage, self).__init__()
-
-        with torch.no_grad():
-            clone_param = lambda t: t.clone().detach().float()
-            self.params = tensor_tree_map(clone_param, model.state_dict())
-
-        self.decay = decay
-
-    def _update_state_dict_(self, update, state_dict):
-        with torch.no_grad():
-            for k, v in update.items():
-                if state_dict[k].device != v.device:
-                    state_dict[k] = state_dict[k].to(v.device)
-                stored = state_dict[k]
-                if not isinstance(v, torch.Tensor):
-                    self._update_state_dict_(v, stored)
-                else:
-                    diff = stored - v.float()
-                    diff *= 1 - self.decay
-                    stored -= diff
-
-    def update(self, model: torch.nn.Module) -> None:
-        """
-        Updates the stored parameters using the state dict of the provided
-        module. The module should have the same structure as that used to
-        initialize the ExponentialMovingAverage object.
-        """
-        self._update_state_dict_(model.state_dict(), self.params)
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        self.params = state_dict["params"]
-        self.decay = state_dict["decay"] if "decay" in state_dict else self.decay
-
-    def state_dict(self) -> dict:
-        return {
-            "params": self.params,
-            "decay": self.decay,
-        }
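Both the removed class above and the new ExponentialMovingAverageModel apply the docstring's formula in the in-place form stored -= (stored - new) * (1 - decay). A quick numeric check (values assumed) that this equals the textbook update:

import torch

decay = 0.999
stored = torch.tensor([1.0, 2.0])
new = torch.tensor([3.0, 4.0])

textbook = decay * stored + (1 - decay) * new  # copy = decay*copy + (1-decay)*param
stored -= (stored - new) * (1 - decay)         # in-place form used by update()
assert torch.allclose(stored, textbook)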
class Trainer(object):
"""Main class for data parallel training.
@@ -167,8 +107,21 @@ class Trainer(object):
self.cuda_env_arr = None
# add ema
-        if args.ema_decay > 0 and self.data_parallel_rank == 0:
-            self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
+        if args.validate_with_ema:
+            assert args.ema_decay > 0, "--validate-with-ema requires --ema-decay > 0"
+        if args.ema_decay > 0 and (
+            self.data_parallel_rank == 0 or args.validate_with_ema
+        ):
+            assert isinstance(
+                self.optimizer, optim.FP16Optimizer
+            ), "EMA requires the FP16 optimizer"
+            self.ema = ExponentialMovingAverageModel(
+                model,
+                args.ema_decay,
+                self._optimizer.fp32_params,
+            )
+        else:
+            self.ema = None
metrics.log_start_time("wall", priority=790, round=2)
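The condition above turns EMA on for rank 0, which writes checkpoints, and additionally for every rank when --validate-with-ema is set, since all ranks then need the EMA weights at validation time. A pure-Python sketch of that gating (the rank values are assumptions):

def ema_enabled(ema_decay, rank, validate_with_ema):
    return ema_decay > 0 and (rank == 0 or validate_with_ema)

assert ema_enabled(0.999, rank=0, validate_with_ema=False)      # checkpointing rank
assert not ema_enabled(0.999, rank=1, validate_with_ema=False)  # other ranks skip EMA
assert ema_enabled(0.999, rank=1, validate_with_ema=True)       # needed for validation
assert not ema_enabled(-1.0, rank=0, validate_with_ema=False)   # EMA disabled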
@@ -434,7 +387,9 @@ class Trainer(object):
            logger.info(
                "Cannot find EMA state in checkpoint; loading model weights into EMA directly"
            )
-            self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
+            self.ema = ExponentialMovingAverageModel(
+                self._model, decay=self.ema.decay
+            )
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
@@ -730,7 +685,7 @@ class Trainer(object):
)
if self.ema is not None:
with torch.autograd.profiler.record_function("ema"):
-                    self.ema.update(self.model)
+                    self.ema.update(self.optimizer.fp32_params)
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print
unicore/utils.py

@@ -420,3 +420,17 @@ def set_jit_fusion_options():
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
+
+@contextlib.contextmanager
+def validate_with_ema(trainer, ema=False):
+    if not ema:
+        yield
+        return
+    _wrapped_model = trainer._wrapped_model
+    trainer._wrapped_model = trainer.ema.model_ema
+    try:
+        yield
+    finally:
+        trainer._wrapped_model = _wrapped_model
\ No newline at end of file
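A minimal sketch of the swap semantics of validate_with_ema (the dummy objects stand in for the real Trainer, which exposes _wrapped_model and an ema with a model_ema attribute; the helper is repeated only to keep the snippet self-contained). The real call site follows below:

import contextlib

@contextlib.contextmanager
def validate_with_ema(trainer, ema=False):  # copy of the helper above
    if not ema:
        yield
        return
    _wrapped_model = trainer._wrapped_model
    trainer._wrapped_model = trainer.ema.model_ema
    try:
        yield
    finally:
        trainer._wrapped_model = _wrapped_model

class _DummyEMA:
    model_ema = "ema-model"

class _DummyTrainer:
    _wrapped_model = "raw-model"
    ema = _DummyEMA()

trainer = _DummyTrainer()
with validate_with_ema(trainer, ema=True):
    assert trainer._wrapped_model == "ema-model"  # validation sees EMA weights
assert trainer._wrapped_model == "raw-model"      # original model restored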
@@ -302,7 +302,8 @@ def validate_and_save(
# Validate
valid_losses = [None]
if do_validate:
-        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+        with utils.validate_with_ema(trainer, ema=args.validate_with_ema):
+            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
should_stop |= should_stop_early(args, valid_losses[0])