Commit c5bf9222 authored by Licheng Yu's avatar Licheng Yu Committed by Facebook GitHub Bot
Browse files

missing keys in _convert_to_d2

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/462

Fix errors in `_convert_to_d2`. Sometimes the keys are missing, we don't need remove them.

{F860805441}

Reviewed By: newstzpz

Differential Revision: D42929485

fbshipit-source-id: 8584879df5a07cbe5a864b4f170eef3d5f34dd6c
parent 6940fa9c
...@@ -93,10 +93,11 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None: ...@@ -93,10 +93,11 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
prefix = "model" # based on DefaultTask.model. prefix = "model" # based on DefaultTask.model.
old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]] old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
for key in old_keys: for key in old_keys:
lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[ if f"{prefix}.{key}" in lightning_checkpoint[_STATE_DICT_KEY]:
_STATE_DICT_KEY lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
][f"{prefix}.{key}"] _STATE_DICT_KEY
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"] ][f"{prefix}.{key}"]
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
for old, new in zip( for old, new in zip(
[_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"] [_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
...@@ -112,11 +113,15 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None: ...@@ -112,11 +113,15 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
lightning_checkpoint[new] = [lightning_checkpoint[old]] lightning_checkpoint[new] = [lightning_checkpoint[old]]
del lightning_checkpoint[old] del lightning_checkpoint[old]
del lightning_checkpoint["epoch"] for key in [
del lightning_checkpoint["pytorch-lightning_version"] "epoch",
del lightning_checkpoint["callbacks"] "pytorch-lightning_versio",
del lightning_checkpoint["hparams_name"] "callbacks",
del lightning_checkpoint["hyper_parameters"] "hparams_name",
"hyper_parameters",
]:
if key in lightning_checkpoint:
del lightning_checkpoint[key]
class ModelTag(str, Enum): class ModelTag(str, Enum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment