Commit 21cb6b39 authored by Guolin Ke's avatar Guolin Ke
Browse files

warning about missing keys in loading model

parent 8da5eaaf
...@@ -27,6 +27,7 @@ from unicore.utils import tensor_tree_map ...@@ -27,6 +27,7 @@ from unicore.utils import tensor_tree_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExponentialMovingAverage: class ExponentialMovingAverage:
""" """
Maintains moving averages of parameters with exponential decay Maintains moving averages of parameters with exponential decay
...@@ -164,7 +165,7 @@ class Trainer(object): ...@@ -164,7 +165,7 @@ class Trainer(object):
else: else:
self.cuda_env = None self.cuda_env = None
self.cuda_env_arr = None self.cuda_env_arr = None
# add ema # add ema
if args.ema_decay > 0 and self.data_parallel_rank == 0: if args.ema_decay > 0 and self.data_parallel_rank == 0:
self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay) self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
...@@ -207,9 +208,7 @@ class Trainer(object): ...@@ -207,9 +208,7 @@ class Trainer(object):
@property @property
def use_distributed_wrapper(self) -> bool: def use_distributed_wrapper(self) -> bool:
return ( return self.data_parallel_world_size > 1
self.data_parallel_world_size > 1
)
@property @property
def should_save_checkpoint_on_current_rank(self) -> bool: def should_save_checkpoint_on_current_rank(self) -> bool:
...@@ -224,10 +223,7 @@ class Trainer(object): ...@@ -224,10 +223,7 @@ class Trainer(object):
@property @property
def loss(self): def loss(self):
if self._wrapped_loss is None: if self._wrapped_loss is None:
if ( if utils.has_parameters(self._loss) and self.use_distributed_wrapper:
utils.has_parameters(self._loss)
and self.use_distributed_wrapper
):
self._wrapped_loss = models.DistributedUnicoreModel( self._wrapped_loss = models.DistributedUnicoreModel(
self.args, self.args,
self._loss, self._loss,
...@@ -281,7 +277,7 @@ class Trainer(object): ...@@ -281,7 +277,7 @@ class Trainer(object):
"please switch to FP32 which is likely to be faster" "please switch to FP32 which is likely to be faster"
) )
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
if self.args.allreduce_fp32_grad: if self.args.allreduce_fp32_grad:
assert self.args.ddp_backend == "no_c10d" assert self.args.ddp_backend == "no_c10d"
if self.args.per_sample_clip_norm > 0: if self.args.per_sample_clip_norm > 0:
...@@ -290,7 +286,7 @@ class Trainer(object): ...@@ -290,7 +286,7 @@ class Trainer(object):
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16") logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params) self._optimizer = optim.build_optimizer(self.args, params)
# We should initialize the learning rate scheduler immediately after # We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set. # building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler( self._lr_scheduler = lr_scheduler.build_lr_scheduler(
...@@ -305,8 +301,7 @@ class Trainer(object): ...@@ -305,8 +301,7 @@ class Trainer(object):
"args": self.args, "args": self.args,
"model": self.model.state_dict(), "model": self.model.state_dict(),
"loss": ( "loss": (
self.loss.state_dict() self.loss.state_dict() if utils.has_parameters(self.loss) else None
if utils.has_parameters(self.loss) else None
), ),
"optimizer_history": (self._optim_history or []) "optimizer_history": (self._optim_history or [])
+ [ + [
...@@ -321,7 +316,7 @@ class Trainer(object): ...@@ -321,7 +316,7 @@ class Trainer(object):
"extra_state": { "extra_state": {
"metrics": metrics.state_dict(), "metrics": metrics.state_dict(),
"previous_training_time": self.cumulative_training_time(), "previous_training_time": self.cumulative_training_time(),
} },
} }
if not self.args.no_save_optimizer_state: if not self.args.no_save_optimizer_state:
state_dict["last_optimizer_state"] = self.optimizer.state_dict() state_dict["last_optimizer_state"] = self.optimizer.state_dict()
...@@ -375,7 +370,7 @@ class Trainer(object): ...@@ -375,7 +370,7 @@ class Trainer(object):
state = None state = None
if is_master: if is_master:
state = checkpoint_utils.load_checkpoint_to_cpu( state = checkpoint_utils.load_checkpoint_to_cpu(
filename, filename,
) )
if is_distributed: if is_distributed:
logger.info("Broadcast checkpoint from rank_0") logger.info("Broadcast checkpoint from rank_0")
...@@ -392,19 +387,28 @@ class Trainer(object): ...@@ -392,19 +387,28 @@ class Trainer(object):
try: try:
if self.args.load_from_ema: if self.args.load_from_ema:
logger.info("loading ema state to model") logger.info("loading ema state to model")
self.model.load_state_dict( errors = self.model.load_state_dict(
ema_state["params"], strict=False, model_args=self.args ema_state["params"], strict=False, model_args=self.args
) )
else: else:
self.model.load_state_dict( errors = self.model.load_state_dict(
state["model"], strict=False, model_args=self.args state["model"], strict=False, model_args=self.args
) )
# save memory for later steps # save memory for later steps
del state["model"] del state["model"]
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict( if errors.missing_keys:
state["loss"], strict=True logger.warning(
"Error in loading model state, missing_keys " +
str(errors.missing_keys)
)
if errors.unexpected_keys:
logger.warning(
"Error in loading model state, unexpected_keys " +
str(errors.unexpected_keys)
) )
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(state["loss"], strict=True)
del state["loss"] del state["loss"]
except Exception: except Exception:
...@@ -413,13 +417,21 @@ class Trainer(object): ...@@ -413,13 +417,21 @@ class Trainer(object):
"please ensure that the architectures match.".format(filename) "please ensure that the architectures match.".format(filename)
) )
extra_state = state["extra_state"] if "extra_state" in state else None extra_state = state["extra_state"] if "extra_state" in state else None
self._optim_history = state["optimizer_history"] if "optimizer_history" in state else None self._optim_history = (
state["optimizer_history"] if "optimizer_history" in state else None
if ema_state is not None and self.ema is not None and not self.args.load_from_ema: )
if (
ema_state is not None
and self.ema is not None
and not self.args.load_from_ema
):
logger.info(f"Loading EMA state...") logger.info(f"Loading EMA state...")
self.ema.load_state_dict(ema_state) self.ema.load_state_dict(ema_state)
elif self.ema is not None: elif self.ema is not None:
logger.info(f"Cannot find EMA state in checkpoint, load model weight to ema directly") logger.info(
f"Cannot find EMA state in checkpoint, load model weight to ema directly"
)
self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay) self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
if last_optim_state is not None and not reset_optimizer: if last_optim_state is not None and not reset_optimizer:
...@@ -437,7 +449,7 @@ class Trainer(object): ...@@ -437,7 +449,7 @@ class Trainer(object):
if not reset_lr_scheduler: if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self.set_num_updates(last_optim["num_updates"]) self.set_num_updates(last_optim["num_updates"])
...@@ -452,7 +464,10 @@ class Trainer(object): ...@@ -452,7 +464,10 @@ class Trainer(object):
# self.lr_step(epoch) # self.lr_step(epoch)
if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: if (
itr_state.get("version", 1) >= 2
and itr_state["iterations_in_epoch"] == 0
):
# reset meters at start of epoch # reset meters at start of epoch
reset_meters = True reset_meters = True
...@@ -511,10 +526,12 @@ class Trainer(object): ...@@ -511,10 +526,12 @@ class Trainer(object):
def init_total_train_steps(self, epoch_itr): def init_total_train_steps(self, epoch_itr):
if self.args.max_epoch > 0: if self.args.max_epoch > 0:
self._total_train_steps = (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch self._total_train_steps = (
(len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
)
else: else:
self._total_train_steps = self.args.max_update self._total_train_steps = self.args.max_update
def get_valid_iterator( def get_valid_iterator(
self, self,
subset, subset,
...@@ -589,7 +606,9 @@ class Trainer(object): ...@@ -589,7 +606,9 @@ class Trainer(object):
try: try:
with maybe_no_sync(): with maybe_no_sync():
# use different seed for different rank in training, otherwise the dropout will be the same in different workers. # use different seed for different rank in training, otherwise the dropout will be the same in different workers.
with utils.torch_seed(self.args.seed, self.get_num_updates(), self.data_parallel_rank): with utils.torch_seed(
self.args.seed, self.get_num_updates(), self.data_parallel_rank
):
# forward and backward # forward and backward
loss, sample_size_i, logging_output = self.task.train_step( loss, sample_size_i, logging_output = self.task.train_step(
sample=sample, sample=sample,
...@@ -601,7 +620,9 @@ class Trainer(object): ...@@ -601,7 +620,9 @@ class Trainer(object):
) )
del loss del loss
if self.args.per_sample_clip_norm > 0: if self.args.per_sample_clip_norm > 0:
self.optimizer.per_sample_clip_grad_norm(self.args.per_sample_clip_norm) self.optimizer.per_sample_clip_grad_norm(
self.args.per_sample_clip_norm
)
logging_outputs.append(logging_output) logging_outputs.append(logging_output)
sample_size += sample_size_i sample_size += sample_size_i
...@@ -647,7 +668,12 @@ class Trainer(object): ...@@ -647,7 +668,12 @@ class Trainer(object):
ooms, ooms,
total_train_time, total_train_time,
) = self._aggregate_logging_outputs( ) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, is_train=True, logging_outputs,
sample_size,
ooms,
train_time,
ignore=is_dummy_batch,
is_train=True,
) )
self._cumulative_training_time = ( self._cumulative_training_time = (
total_train_time / self.data_parallel_world_size total_train_time / self.data_parallel_world_size
...@@ -670,11 +696,7 @@ class Trainer(object): ...@@ -670,11 +696,7 @@ class Trainer(object):
# (Debugging note: Some optimizers perform this scaling on the # (Debugging note: Some optimizers perform this scaling on the
# fly, so inspecting model.parameters() or optimizer.params may # fly, so inspecting model.parameters() or optimizer.params may
# still show the original, unscaled gradients.) # still show the original, unscaled gradients.)
numer = ( numer = self.data_parallel_world_size if self._sync_stats() else 1
self.data_parallel_world_size
if self._sync_stats()
else 1
)
self.optimizer.multiply_grads(numer / (sample_size or 1.0)) self.optimizer.multiply_grads(numer / (sample_size or 1.0))
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
...@@ -695,7 +717,9 @@ class Trainer(object): ...@@ -695,7 +717,9 @@ class Trainer(object):
with utils.torch_seed(self.args.seed, self.get_num_updates(), -1): with utils.torch_seed(self.args.seed, self.get_num_updates(), -1):
# take an optimization step # take an optimization step
self.task.optimizer_step( self.task.optimizer_step(
self.optimizer, model=self.model, update_num=self.get_num_updates() self.optimizer,
model=self.model,
update_num=self.get_num_updates(),
) )
if self.ema is not None: if self.ema is not None:
with torch.autograd.profiler.record_function("ema"): with torch.autograd.profiler.record_function("ema"):
...@@ -719,7 +743,9 @@ class Trainer(object): ...@@ -719,7 +743,9 @@ class Trainer(object):
raise raise
except OverflowError as e: except OverflowError as e:
overflow = True overflow = True
logger.info(f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}") logger.info(
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
)
grad_norm = torch.tensor(0.0).cuda() grad_norm = torch.tensor(0.0).cuda()
self.zero_grad() self.zero_grad()
except RuntimeError as e: except RuntimeError as e:
...@@ -737,13 +763,13 @@ class Trainer(object): ...@@ -737,13 +763,13 @@ class Trainer(object):
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
gb_free = self.cuda_env.total_memory_in_GB - gb_used gb_free = self.cuda_env.total_memory_in_GB - gb_used
metrics.log_scalar( metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0)
"gb_free", gb_free, priority=1500, round=1, weight=0
)
# log stats # log stats
logging_output = self._reduce_and_log_stats( logging_output = self._reduce_and_log_stats(
logging_outputs, sample_size, grad_norm, logging_outputs,
sample_size,
grad_norm,
) )
# clear CUDA cache to reduce memory fragmentation # clear CUDA cache to reduce memory fragmentation
...@@ -865,9 +891,7 @@ class Trainer(object): ...@@ -865,9 +891,7 @@ class Trainer(object):
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def clip_grad_norm(self, clip_norm): def clip_grad_norm(self, clip_norm):
return self.optimizer.clip_grad_norm( return self.optimizer.clip_grad_norm(clip_norm)
clip_norm
)
def cumulative_training_time(self): def cumulative_training_time(self):
if self._cumulative_training_time is None: if self._cumulative_training_time is None:
...@@ -908,7 +932,7 @@ class Trainer(object): ...@@ -908,7 +932,7 @@ class Trainer(object):
return t.to(dtype=torch.bfloat16) return t.to(dtype=torch.bfloat16)
return t return t
# Please manually convert data type by yourself. # Please manually convert data type by yourself.
# if self.args.fp16: # if self.args.fp16:
# sample = utils.apply_to_sample(apply_half, sample) # sample = utils.apply_to_sample(apply_half, sample)
...@@ -942,7 +966,9 @@ class Trainer(object): ...@@ -942,7 +966,9 @@ class Trainer(object):
ignore=False, ignore=False,
is_train=False, is_train=False,
): ):
if self.task.__class__.logging_outputs_can_be_summed(self.get_loss(), is_train=is_train): if self.task.__class__.logging_outputs_can_be_summed(
self.get_loss(), is_train=is_train
):
return self._fast_stat_sync_sum( return self._fast_stat_sync_sum(
logging_outputs, *extra_stats_to_sum, ignore=ignore logging_outputs, *extra_stats_to_sum, ignore=ignore
) )
...@@ -978,7 +1004,10 @@ class Trainer(object): ...@@ -978,7 +1004,10 @@ class Trainer(object):
return logging_outputs, extra_stats_to_sum return logging_outputs, extra_stats_to_sum
def _fast_stat_sync_sum( def _fast_stat_sync_sum(
self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
): ):
""" """
Sync logging outputs across workers. fast_stat_sync_sum is Sync logging outputs across workers. fast_stat_sync_sum is
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment