Unverified Commit 9d0d0709 authored by Chenguo Lin's avatar Chenguo Lin Committed by GitHub
Browse files

EMA: fix `state_dict()` and `load_state_dict()` & add `cur_decay_value` (#2146)

* EMA: fix `state_dict()` & add `cur_decay_value`

* EMA: fix a bug in `load_state_dict()`

'float' object (`state_dict["power"]`) has no attribute 'get'.

* del train_unconditional_ort.py
parent c1971a53
...@@ -563,7 +563,7 @@ def main(args): ...@@ -563,7 +563,7 @@ def main(args):
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema: if args.use_ema:
logs["ema_decay"] = ema_model.decay logs["ema_decay"] = ema_model.cur_decay_value
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
progress_bar.close() progress_bar.close()
......
...@@ -124,6 +124,7 @@ class EMAModel: ...@@ -124,6 +124,7 @@ class EMAModel:
self.inv_gamma = inv_gamma self.inv_gamma = inv_gamma
self.power = power self.power = power
self.optimization_step = 0 self.optimization_step = 0
self.cur_decay_value = None # set in `step()`
self.model_cls = model_cls self.model_cls = model_cls
self.model_config = model_config self.model_config = model_config
...@@ -194,6 +195,7 @@ class EMAModel: ...@@ -194,6 +195,7 @@ class EMAModel:
# Compute the decay factor for the exponential moving average. # Compute the decay factor for the exponential moving average.
decay = self.get_decay(self.optimization_step) decay = self.get_decay(self.optimization_step)
self.cur_decay_value = decay
one_minus_decay = 1 - decay one_minus_decay = 1 - decay
for s_param, param in zip(self.shadow_params, parameters): for s_param, param in zip(self.shadow_params, parameters):
...@@ -239,7 +241,7 @@ class EMAModel: ...@@ -239,7 +241,7 @@ class EMAModel:
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
return { return {
"decay": self.decay, "decay": self.decay,
"min_decay": self.decay, "min_decay": self.min_decay,
"optimization_step": self.optimization_step, "optimization_step": self.optimization_step,
"update_after_step": self.update_after_step, "update_after_step": self.update_after_step,
"use_ema_warmup": self.use_ema_warmup, "use_ema_warmup": self.use_ema_warmup,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment