"tests/python/pytorch/graphbolt/test_minibatch.py" did not exist on "e242de9cb57eded6ee745ba85ee8f3faa63e4567"
Commit 5c23bee8 authored by Anthony Chen, committed by Facebook GitHub Bot

remove AC prefix from EMA to make it compatible with loading

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/578

# Problem:
d2go EMA uses `named_parameters()` to traverse model states and save EMA checkpoints, while it uses `state_dict()` to save model checkpoints. This is brittle because `named_parameters()` and `state_dict()` go through two different sets of Python APIs and can return different things.
In the case of Activation Checkpointing (AC), we don't want the AC wrapper to affect checkpoint names, so PyTorch overrides `state_dict()` to remove the "_checkpoint_wrapped_module" prefix from FQNs. However, `named_parameters()` has no such support, so the prefix still appears there. If we change the AC wrapping strategy (very common during optimization), we will not be able to load the previous EMA state back into the model. The same problem also occurs with FSDP.
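
As a rough illustration (not part of this diff), here is a minimal sketch of the mismatch on a toy model. It assumes PyTorch's `checkpoint_wrapper` API; the module and the example key names are made up:

```python
# Sketch only: shows how AC wrapping leaks into named_parameters() but not state_dict().
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
model[0] = checkpoint_wrapper(model[0])  # wrap one submodule with AC

print([name for name, _ in model.named_parameters()])
# expected to contain the AC prefix, e.g. '0._checkpoint_wrapped_module.weight'

print(list(model.state_dict().keys()))
# expected to be prefix-free, e.g. '0.weight', since state_dict() strips the prefix
```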

# Short-term hack:
This diff adds a short-term hack to manually remove the AC prefix in EMA. The list of ignored prefixes (`EMAState.prefix_to_remove`) can be expanded to support more use cases.
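
For intuition, a minimal sketch of the name normalization the hack performs, mirroring the `_remove_prefix` helper added below; the `strip_prefixes` name and the hard-coded prefix value are illustrative assumptions, not part of this diff (the real code imports PyTorch's `_CHECKPOINT_PREFIX` constant):

```python
# Sketch only: strip listed prefixes from FQNs so EMA names match state_dict() names.
from typing import Iterator, List, Tuple

import torch

ASSUMED_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."  # assumed value of _CHECKPOINT_PREFIX


def strip_prefixes(
    named_iterator: Iterator[Tuple[str, torch.Tensor]], prefixes: List[str]
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Remove every listed prefix from each name before yielding it.
    for name, tensor in named_iterator:
        for prefix in prefixes:
            name = name.replace(prefix, "")
        yield name, tensor


names = [("backbone._checkpoint_wrapped_module.conv.weight", torch.zeros(1))]
print([n for n, _ in strip_prefixes(iter(names), [ASSUMED_CHECKPOINT_PREFIX])])
# ['backbone.conv.weight']
```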

Reviewed By: wat3rBro

Differential Revision: D46815031

fbshipit-source-id: 29b6ea444ed2ef90b8741fccdcb2b62625933e7f
parent c0a84df5
@@ -6,10 +6,13 @@ import copy
 import itertools
 import logging
 from contextlib import contextmanager
-from typing import List
+from typing import Iterator, List

 import torch
 from detectron2.engine.train_loop import HookBase
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+)

 logger = logging.getLogger(__name__)
@@ -20,6 +23,9 @@ class EMAState(object):
         self.include_frozen = include_frozen
         self.include_buffer = include_buffer
         self.state = {}
+        # HACK: This hack is needed to strip checkpoint wrapper prefix from fqns so it doesn't affect loading.
+        # TODO: Remove this hack by rewriting EMAState to use model.state_dict()
+        self.prefix_to_remove = [_CHECKPOINT_PREFIX]

     @classmethod
     def FromModel(cls, model: torch.nn.Module, device: str = "", **kwargs):
@@ -72,14 +78,19 @@ class EMAState(object):
         self.state.clear()
         return self

+    def _get_model_parameter_iterator(self, model):
+        """
+        Return iterator for model parameters. Remove frozen parameters if needed.
+        """
+        for name, params in model.named_parameters():
+            if params.requires_grad or self.include_frozen:
+                yield name, params
+
     def get_model_state_iterator(self, model):
-        param_iter = model.named_parameters()
-        if not self.include_frozen:
-            param_iter = iter((n, v) for n, v in param_iter if v.requires_grad)
+        param_iter = self._get_model_parameter_iterator(model)
         if self.include_buffer:
-            buffer_iter = model.named_buffers()
-            return itertools.chain(param_iter, buffer_iter)
-        return param_iter
+            param_iter = itertools.chain(param_iter, model.named_buffers())
+        return _remove_prefix(param_iter, self.prefix_to_remove)

     def state_dict(self):
         return self.state
@@ -207,6 +218,16 @@ def _remove_ddp(model):
     return model


+def _remove_prefix(named_iterator: Iterator, prefix_to_remove: List[str]) -> Iterator:
+    """
+    Remove a list of prefix from a named_module iterator
+    """
+    for name, params in named_iterator:
+        for prefix in prefix_to_remove:
+            name = name.replace(prefix, "")
+        yield name, params
+
+
 def may_build_model_ema(cfg, model):
     if not cfg.MODEL_EMA.ENABLED:
         return
@@ -112,8 +112,6 @@ class TestActivationCheckpointing(unittest.TestCase):

     @tempdir
     def test_ac_runner(self, tmp_dir) -> None:
-        tmp_dir = "/tmp/test"
-        os.makedirs(tmp_dir, exist_ok=True)
         ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
         runner = Detectron2GoRunner()
         cfg = _get_cfg(runner, tmp_dir, ds_name)
@@ -124,3 +122,10 @@ class TestActivationCheckpointing(unittest.TestCase):
         model = runner.build_model(cfg)
         runner.do_train(cfg, model, resume=False)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "model_0000002.pth")))
+
+        # resume training onto a non-AC-wrapped model
+        cfg.MODEL.MODELING_HOOKS = []
+        cfg.SOLVER.MAX_ITER = 6
+        model = runner.build_model(cfg)
+        runner.do_train(cfg, model, resume=True)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "model_0000005.pth")))