Commit d06a8fb1 authored by Ajinkya Deogade, committed by Facebook GitHub Bot

Trainer part 2: Create a separate TARGET for lightning trainer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/561

This is a continuation of part 1 (D45912069), where we had not yet defined the TARGETS for the lightning trainer.
Now that the circular deps have been resolved, we can define the TARGETS for `d2go/trainer/lightning` and move the other TARGETS inside `d2go/trainer`.
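
A rough sketch of the kind of Buck target this adds for `d2go/trainer/lightning` (the rule name, target name, srcs, and deps below are illustrative assumptions, not copied from the actual TARGETS file):

```python
# Hypothetical Buck/Starlark target for d2go/trainer/lightning.
# Names and deps are placeholders; the real TARGETS file may differ.
python_library(
    name = "lightning",
    srcs = glob(["lightning/**/*.py"]),
    deps = [
        # e.g. the runner target that provides lightning_task,
        # usable now that the circular dependency has been broken
        "//d2go/runner:runner",
    ],
)
```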

Reviewed By: tglik

Differential Revision: D46096373

fbshipit-source-id: 6efc13eb9ab343d11028fb238e6e3f0c64a03e09
parent 0cde431c
@@ -33,6 +33,9 @@ else:
logger = logging.getLogger(__name__)
_CONVERT_FX_CALLBACK_ATTRIBUTE = "_convert_fx_callback"
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
_OLD_EMA_KEY = "ema_state"
def _is_observer_key(state_dict_key):
@@ -76,9 +79,6 @@ class QATCheckpointer(DetectionCheckpointer):
# assume file is from lightning; no one else seems to use the ".ckpt" extension
with PathManager.open(filename, "rb") as f:
data = self._torch_load(f)
# TODO: Remove once buck targets are modularized and directly use
# from d2go.runner.lightning_task import _convert_to_d2
_convert_to_d2(data)
return data
@@ -691,13 +691,6 @@ def forward_custom_prepare_fx(root, sub_module_name, orig_ret):
return root, new_callback
# TODO: Remove once buck targets are modularized and directly use
# from d2go.runner.lightning_task import _convert_to_d2
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
_OLD_EMA_KEY = "ema_state"
def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
prefix = "model" # based on DefaultTask.model.
old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
@@ -89,41 +89,6 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
d2_checkpoint["epoch"] = 0
def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
prefix = "model" # based on DefaultTask.model.
old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
for key in old_keys:
if f"{prefix}.{key}" in lightning_checkpoint[_STATE_DICT_KEY]:
lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
_STATE_DICT_KEY
][f"{prefix}.{key}"]
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
for old, new in zip(
[_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
):
lightning_checkpoint[new] = lightning_checkpoint[old]
del lightning_checkpoint[old]
for old, new in zip(
["optimizer_states", "lr_schedulers"], ["optimizer", "scheduler"]
):
if old not in lightning_checkpoint:
continue
lightning_checkpoint[new] = [lightning_checkpoint[old]]
del lightning_checkpoint[old]
for key in [
"epoch",
"pytorch-lightning_versio",
"callbacks",
"hparams_name",
"hyper_parameters",
]:
if key in lightning_checkpoint:
del lightning_checkpoint[key]
class ModelTag(str, Enum):
DEFAULT = "default"
EMA = "ema"
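
For context on the helper being consolidated: `_convert_to_d2` (now imported from `d2go.runner.lightning_task` rather than duplicated locally) rewrites a Lightning checkpoint dict into the Detectron2 layout. A rough sketch of what that conversion amounts to, using a made-up checkpoint with placeholder values rather than real tensors:

```python
from typing import Any, Dict

# Made-up Lightning-style checkpoint; real ones hold tensors, not ints.
ckpt: Dict[str, Any] = {
    "state_dict": {"model.backbone.weight": 1, "model.head.bias": 2},
    "global_step": 1000,
    "optimizer_states": {"param_groups": []},
    "lr_schedulers": {"last_epoch": 5},
    "epoch": 3,
    "pytorch-lightning_version": "1.9.0",
}

# Roughly mirrors the steps visible in the removed code above:
# 1) drop the "model." prefix that DefaultTask.model adds to state_dict keys,
# 2) rename state_dict -> model and global_step -> iteration,
# 3) wrap optimizer/scheduler states in single-element lists,
# 4) discard Lightning-only bookkeeping keys (epoch, version, callbacks, ...).
state = {
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in ckpt.pop("state_dict").items()
}
d2_ckpt = {
    "model": state,
    "iteration": ckpt.pop("global_step"),
    "optimizer": [ckpt.pop("optimizer_states")],
    "scheduler": [ckpt.pop("lr_schedulers")],
}

print(d2_ckpt["model"])      # {'backbone.weight': 1, 'head.bias': 2}
print(d2_ckpt["iteration"])  # 1000
```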