"git@developer.sourcefind.cn:change/sglang.git" did not exist on "120c3634efa52f733bef4b8290aff9a70ba65215"
Commit 9d9f438b authored by Kai Zhang, committed by Facebook GitHub Bot

Move EMA weights to current device before training

Summary:
Currently we move the EMA weights to the expected device right after loading them from the checkpoint.
However, by the time the `on_load_checkpoint` hook is called, the current GPU device has not been assigned yet. This can leave the EMA weights on cuda:0 while the model is on cuda:1.
This diff moves the EMA weights to the device in `on_pretrain_routine_end` instead (a minimal sketch of the hook ordering follows after the commit metadata below).

Reviewed By: zhanghang1989

Differential Revision: D28429843

fbshipit-source-id: d864fb3687eb6958872300c5ec0af7ce90591f83
parent 5509a138
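To make the hook ordering from the Summary concrete, here is a minimal sketch before the diff below. It assumes a pytorch_lightning version that still provides the `on_pretrain_routine_end` hook (as the changed code does); `_SimpleEMA` and `_ToyTask` are hypothetical stand-ins, not d2go's `EMAState` or `DefaultTask`.

import copy
from typing import Any, Dict

import pytorch_lightning as pl
import torch


class _SimpleEMA:
    """Hypothetical EMA holder; illustrative only."""

    def __init__(self) -> None:
        self.params: Dict[str, torch.Tensor] = {}

    def to(self, device: torch.device) -> None:
        # Move every tracked tensor to the requested device.
        self.params = {k: v.to(device) for k, v in self.params.items()}


class _ToyTask(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(4, 2)
        self.ema = _SimpleEMA()

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Per the Summary: when this hook runs, the trainer has not yet
        # assigned this rank its GPU, so self.device cannot be trusted
        # for placement. Only restore the tensors here.
        self.ema.params = copy.deepcopy(checkpoint.get("model_ema", {}))

    def on_pretrain_routine_end(self) -> None:
        # By now the module sits on its final device, so aligning the EMA
        # weights with self.device is safe.
        self.ema.to(self.device)

Under a Trainer, `on_load_checkpoint` fires while the checkpoint is being restored and `on_pretrain_routine_end` fires after the module has been placed, which is why the change defers the device move to the latter hook.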
@@ -41,6 +41,7 @@ from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
 _STATE_DICT_KEY = "state_dict"
 _OLD_STATE_DICT_KEY = "model"
+_OLD_EMA_KEY = "ema_state"
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -82,6 +83,10 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
         d2_checkpoint[new] = d2_checkpoint[old]
         del d2_checkpoint[old]
+    if _OLD_EMA_KEY in d2_checkpoint:
+        for k, v in d2_checkpoint[_OLD_EMA_KEY].items():
+            d2_checkpoint[_STATE_DICT_KEY][f"model_ema.{k}"] = v
     for old, new in zip(
         ["optimizer", "scheduler"], ["optimizer_states", "lr_schedulers"]
     ):
@@ -90,6 +95,10 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
         d2_checkpoint[new] = [d2_checkpoint[old]]
         del d2_checkpoint[old]
+    if _OLD_EMA_KEY in d2_checkpoint:
+        d2_checkpoint["model_ema"] = d2_checkpoint[_OLD_EMA_KEY]
+        del d2_checkpoint[_OLD_EMA_KEY]
     d2_checkpoint["epoch"] = 0
@@ -367,6 +376,8 @@ class DefaultTask(pl.LightningModule):
         if self.cfg.MODEL_EMA.ENABLED:
             if self.ema_state and self.ema_state.has_inited():
                 # ema_state could have been loaded from checkpoint
+                # move to the current CUDA device if not on CPU
+                self.ema_state.to(self.ema_state.device)
                 return
             self.ema_state = EMAState.from_model(
                 self.model,
@@ -419,11 +430,12 @@
                     "EMA is enabled but EMA state is not found in given checkpoint"
                 )
             else:
-                self.ema_state = EMAState()
+                self.ema_state = EMAState(
+                    decay=self.cfg.MODEL_EMA.DECAY,
+                    device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE,
+                )
                 self.ema_state.load_state_dict(checkpointed_state["model_ema"])
-                if not self.ema_state.device:
-                    # EMA state device not given, move to module device
-                    self.ema_state.to(self.device)
                 rank_zero_info("Loaded EMA state from checkpoint.")

     def prepare_for_quant(self) -> pl.LightningModule:
         if hasattr(self.model, "prepare_for_quant"):
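One detail that may look odd in the `on_pretrain_routine_end` hunk is `self.ema_state.to(self.ema_state.device)`. A guess at why that is not a no-op, sketched with a made-up class (not d2go's `EMAState`): `device` there would be a remembered target, while the checkpointed tensors may still sit elsewhere (typically on CPU).

from typing import Dict, Optional

import torch


class _EMAStateSketch:
    """Illustrative only; not d2go's EMAState."""

    def __init__(self, decay: float = 0.999, device: Optional[str] = None) -> None:
        self.decay = decay
        self.device = device  # intended placement, e.g. "cuda" or "cpu"
        self.state: Dict[str, torch.Tensor] = {}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        # Checkpointed tensors usually arrive on CPU regardless of self.device.
        self.state = dict(state_dict)

    def to(self, device: Optional[str]) -> None:
        if device:
            self.state = {k: v.to(device) for k, v in self.state.items()}


ema = _EMAStateSketch(decay=0.9998, device="cpu")
ema.load_state_dict({"w": torch.zeros(3)})
# Later, once the trainer has placed the module, re-apply the remembered target:
ema.to(ema.device)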