Commit 6cff7737 authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

Fix EMA model training with lightning

Summary:
Current implementation of d2go lightning default task fails when running a model training with EMA.
The error is :
```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss.
```
The error is due the fact the d2go lightning task create a copy of the ema model for evaluation that does not included in the training, which raise the error that there are unused params.
This is solved by moving the copy creation to after training and to when evaluation starts.

Reviewed By: kazhang

Differential Revision: D33442690

fbshipit-source-id: e9e469e33811de0b4171a64293cc16a8157af08c
parent aeb15613
......@@ -86,10 +86,6 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
d2_checkpoint[new] = d2_checkpoint[old]
del d2_checkpoint[old]
if _OLD_EMA_KEY in d2_checkpoint:
for k, v in d2_checkpoint[_OLD_EMA_KEY].items():
d2_checkpoint[_STATE_DICT_KEY][f"model_ema.{k}"] = v
for old, new in zip(
["optimizer", "scheduler"], ["optimizer_states", "lr_schedulers"]
):
......@@ -129,7 +125,6 @@ class DefaultTask(pl.LightningModule):
decay=cfg.MODEL_EMA.DECAY,
device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE,
)
self.model_ema = deepcopy(self.model)
self.dataset_evaluators[ModelTag.EMA] = []
def _build_model(self) -> torch.nn.Module:
......@@ -419,8 +414,17 @@ class DefaultTask(pl.LightningModule):
def _on_evaluation_epoch_start(self):
if self.ema_state:
self.model_ema = deepcopy(self.model)
self.ema_state.apply_to(self.model_ema)
def on_validation_epoch_end(self):
if self.ema_state and hasattr(self, "model_ema"):
del self.model_ema
def on_test_epoch_end(self):
if self.ema_state and hasattr(self, "model_ema"):
del self.model_ema
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.ema_state:
checkpoint["model_ema"] = self.ema_state.state_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment