Commit 21cb6b39 authored by Guolin Ke

Warn about missing keys when loading model state

parent 8da5eaaf
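The substance of this commit: `load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys` fields, which the old code silently discarded; the Trainer now captures it and logs a warning for each non-empty field. A minimal sketch of the same pattern in plain PyTorch (hypothetical `load_with_warnings` helper and `checkpoint_path`, not the Uni-Core wrapper):

import logging

import torch
from torch import nn

logger = logging.getLogger(__name__)

def load_with_warnings(model: nn.Module, checkpoint_path: str) -> None:
    # strict=False returns an _IncompatibleKeys named tuple instead of
    # raising a RuntimeError when checkpoint and model keys differ
    state = torch.load(checkpoint_path, map_location="cpu")
    errors = model.load_state_dict(state, strict=False)
    if errors.missing_keys:
        logger.warning("Missing keys when loading model state: %s", errors.missing_keys)
    if errors.unexpected_keys:
        logger.warning("Unexpected keys when loading model state: %s", errors.unexpected_keys)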
@@ -27,6 +27,7 @@ from unicore.utils import tensor_tree_map
logger = logging.getLogger(__name__)
class ExponentialMovingAverage:
"""
Maintains moving averages of parameters with exponential decay
@@ -164,7 +165,7 @@ class Trainer(object):
else:
self.cuda_env = None
self.cuda_env_arr = None
# add ema
if args.ema_decay > 0 and self.data_parallel_rank == 0:
self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
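For context on the `ema` attribute created above: an EMA keeps a shadow copy of every parameter, updated as `shadow = decay * shadow + (1 - decay) * param` after each optimizer step. A minimal sketch of such a class (an illustration under that assumption, not Uni-Core's ExponentialMovingAverage):

import torch

class SimpleEMA:
    """Maintains exponential moving averages of a model's parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # detached shadow copies, kept on the same device as the model
        self.shadow = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        # shadow <- decay * shadow + (1 - decay) * param
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p, alpha=1.0 - self.decay)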
@@ -207,9 +208,7 @@ class Trainer(object):
@property
def use_distributed_wrapper(self) -> bool:
return (
self.data_parallel_world_size > 1
)
return self.data_parallel_world_size > 1
@property
def should_save_checkpoint_on_current_rank(self) -> bool:
@@ -224,10 +223,7 @@ class Trainer(object):
@property
def loss(self):
if self._wrapped_loss is None:
if (
utils.has_parameters(self._loss)
and self.use_distributed_wrapper
):
if utils.has_parameters(self._loss) and self.use_distributed_wrapper:
self._wrapped_loss = models.DistributedUnicoreModel(
self.args,
self._loss,
@@ -281,7 +277,7 @@ class Trainer(object):
"please switch to FP32 which is likely to be faster"
)
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
if self.args.allreduce_fp32_grad:
assert self.args.ddp_backend == "no_c10d"
if self.args.per_sample_clip_norm > 0:
@@ -290,7 +286,7 @@ class Trainer(object):
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params)
# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler(
@@ -305,8 +301,7 @@ class Trainer(object):
"args": self.args,
"model": self.model.state_dict(),
"loss": (
self.loss.state_dict()
if utils.has_parameters(self.loss) else None
self.loss.state_dict() if utils.has_parameters(self.loss) else None
),
"optimizer_history": (self._optim_history or [])
+ [
@@ -321,7 +316,7 @@ class Trainer(object):
"extra_state": {
"metrics": metrics.state_dict(),
"previous_training_time": self.cumulative_training_time(),
}
},
}
if not self.args.no_save_optimizer_state:
state_dict["last_optimizer_state"] = self.optimizer.state_dict()
@@ -375,7 +370,7 @@ class Trainer(object):
state = None
if is_master:
state = checkpoint_utils.load_checkpoint_to_cpu(
filename,
filename,
)
if is_distributed:
logger.info("Broadcast checkpoint from rank_0")
@@ -392,19 +387,28 @@ class Trainer(object):
try:
if self.args.load_from_ema:
logger.info("loading ema state to model")
self.model.load_state_dict(
errors = self.model.load_state_dict(
ema_state["params"], strict=False, model_args=self.args
)
else:
self.model.load_state_dict(
errors = self.model.load_state_dict(
state["model"], strict=False, model_args=self.args
)
# save memory for later steps
del state["model"]
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(
state["loss"], strict=True
if errors.missing_keys:
logger.warning(
"Missing keys when loading model state: "
+ str(errors.missing_keys)
)
if errors.unexpected_keys:
logger.warning(
"Unexpected keys when loading model state: "
+ str(errors.unexpected_keys)
)
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(state["loss"], strict=True)
del state["loss"]
except Exception:
@@ -413,13 +417,21 @@ class Trainer(object):
"please ensure that the architectures match.".format(filename)
)
extra_state = state["extra_state"] if "extra_state" in state else None
self._optim_history = state["optimizer_history"] if "optimizer_history" in state else None
if ema_state is not None and self.ema is not None and not self.args.load_from_ema:
self._optim_history = (
state["optimizer_history"] if "optimizer_history" in state else None
)
if (
ema_state is not None
and self.ema is not None
and not self.args.load_from_ema
):
logger.info(f"Loading EMA state...")
self.ema.load_state_dict(ema_state)
elif self.ema is not None:
logger.info(f"Cannot find EMA state in checkpoint, load model weight to ema directly")
logger.info(
"Cannot find EMA state in checkpoint; loading model weights into EMA directly"
)
self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
if last_optim_state is not None and not reset_optimizer:
@@ -437,7 +449,7 @@ class Trainer(object):
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self.set_num_updates(last_optim["num_updates"])
@@ -452,7 +464,10 @@ class Trainer(object):
# self.lr_step(epoch)
if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0:
if (
itr_state.get("version", 1) >= 2
and itr_state["iterations_in_epoch"] == 0
):
# reset meters at start of epoch
reset_meters = True
@@ -511,10 +526,12 @@ class Trainer(object):
def init_total_train_steps(self, epoch_itr):
if self.args.max_epoch > 0:
self._total_train_steps = (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
self._total_train_steps = (
(len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
)
else:
self._total_train_steps = self.args.max_update
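A quick numeric check of the step count computed above, with hypothetical values: 1000 batches per epoch, gradient accumulation (update_freq) of 4, and 10 epochs:

# hypothetical values, purely to illustrate the arithmetic
len_epoch_itr, update_freq, max_epoch = 1000, 4, 10
total_train_steps = (len_epoch_itr + 1) // update_freq * max_epoch
assert total_train_steps == 2500  # 1001 // 4 == 250 optimizer steps per epoch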
def get_valid_iterator(
self,
subset,
@@ -589,7 +606,9 @@ class Trainer(object):
try:
with maybe_no_sync():
# use a different seed for each rank in training; otherwise the dropout masks would be identical across workers.
with utils.torch_seed(self.args.seed, self.get_num_updates(), self.data_parallel_rank):
with utils.torch_seed(
self.args.seed, self.get_num_updates(), self.data_parallel_rank
):
# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
@@ -601,7 +620,9 @@ class Trainer(object):
)
del loss
if self.args.per_sample_clip_norm > 0:
self.optimizer.per_sample_clip_grad_norm(self.args.per_sample_clip_norm)
self.optimizer.per_sample_clip_grad_norm(
self.args.per_sample_clip_norm
)
logging_outputs.append(logging_output)
sample_size += sample_size_i
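The `utils.torch_seed` context manager above mixes the base seed, the update number, and the rank so each worker draws distinct dropout masks while remaining reproducible. A sketch of how such a helper could work (an assumption; Uni-Core's actual implementation may differ):

import contextlib

import torch

@contextlib.contextmanager
def torch_seed(*seeds: int):
    # mix the components (base seed, num updates, rank) into a single seed
    seed = hash(seeds) % (2**31)
    # fork the RNG state so seeding inside does not leak outside the block
    with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        yield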
@@ -647,7 +668,12 @@ class Trainer(object):
ooms,
total_train_time,
) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, is_train=True,
logging_outputs,
sample_size,
ooms,
train_time,
ignore=is_dummy_batch,
is_train=True,
)
self._cumulative_training_time = (
total_train_time / self.data_parallel_world_size
@@ -670,11 +696,7 @@ class Trainer(object):
# (Debugging note: Some optimizers perform this scaling on the
# fly, so inspecting model.parameters() or optimizer.params may
# still show the original, unscaled gradients.)
numer = (
self.data_parallel_world_size
if self._sync_stats()
else 1
)
numer = self.data_parallel_world_size if self._sync_stats() else 1
self.optimizer.multiply_grads(numer / (sample_size or 1.0))
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a
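To see why the multiplier is `world_size / sample_size`: DDP's all-reduce averages gradients (an implicit division by `world_size`), so scaling by `numer / sample_size` leaves a net division by the global sample size. With hypothetical numbers:

# hypothetical: 4 workers, 4096 target tokens summed across all of them
world_size, sample_size = 4, 4096
multiplier = world_size / (sample_size or 1.0)
# all-reduce averaging (1 / world_size) times the multiplier == 1 / sample_size
assert (1 / world_size) * multiplier == 1 / sample_size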
@@ -695,7 +717,9 @@ class Trainer(object):
with utils.torch_seed(self.args.seed, self.get_num_updates(), -1):
# take an optimization step
self.task.optimizer_step(
self.optimizer, model=self.model, update_num=self.get_num_updates()
self.optimizer,
model=self.model,
update_num=self.get_num_updates(),
)
if self.ema is not None:
with torch.autograd.profiler.record_function("ema"):
@@ -719,7 +743,9 @@ class Trainer(object):
raise
except OverflowError as e:
overflow = True
logger.info(f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}")
logger.info(
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
)
grad_norm = torch.tensor(0.0).cuda()
self.zero_grad()
except RuntimeError as e:
@@ -737,13 +763,13 @@ class Trainer(object):
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
torch.cuda.reset_peak_memory_stats()
gb_free = self.cuda_env.total_memory_in_GB - gb_used
metrics.log_scalar(
"gb_free", gb_free, priority=1500, round=1, weight=0
)
metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0)
# log stats
logging_output = self._reduce_and_log_stats(
logging_outputs, sample_size, grad_norm,
logging_outputs,
sample_size,
grad_norm,
)
# clear CUDA cache to reduce memory fragmentation
@@ -865,9 +891,7 @@ class Trainer(object):
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def clip_grad_norm(self, clip_norm):
return self.optimizer.clip_grad_norm(
clip_norm
)
return self.optimizer.clip_grad_norm(clip_norm)
def cumulative_training_time(self):
if self._cumulative_training_time is None:
@@ -908,7 +932,7 @@ class Trainer(object):
return t.to(dtype=torch.bfloat16)
return t
# Please manually convert data type by yourself.
# Please convert the data type manually if needed.
# if self.args.fp16:
# sample = utils.apply_to_sample(apply_half, sample)
@@ -942,7 +966,9 @@ class Trainer(object):
ignore=False,
is_train=False,
):
if self.task.__class__.logging_outputs_can_be_summed(self.get_loss(), is_train=is_train):
if self.task.__class__.logging_outputs_can_be_summed(
self.get_loss(), is_train=is_train
):
return self._fast_stat_sync_sum(
logging_outputs, *extra_stats_to_sum, ignore=ignore
)
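When the logging outputs are summable scalars, the fast path can pack them into a single tensor and issue one all_reduce instead of gathering pickled lists. A sketch of that idea (a hypothetical helper, not the Uni-Core implementation):

from typing import Any, Dict, List

import torch
import torch.distributed as dist

def fast_stat_sync_sum(logging_outputs: List[Dict[str, Any]]) -> Dict[str, float]:
    # sum each scalar key locally first, in a deterministic key order
    keys = sorted({k for log in logging_outputs for k in log})
    stats = torch.tensor(
        [sum(log.get(k, 0) for log in logging_outputs) for k in keys],
        dtype=torch.float64,
    )
    # a single all_reduce sums the packed stats across all workers
    # (with the NCCL backend the tensor would need to live on the GPU)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats)
    return dict(zip(keys, stats.tolist()))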
@@ -978,7 +1004,10 @@ class Trainer(object):
return logging_outputs, extra_stats_to_sum
def _fast_stat_sync_sum(
self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False,
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
"""
Sync logging outputs across workers. fast_stat_sync_sum is