Commit c5bf9222 authored by Licheng Yu's avatar Licheng Yu Committed by Facebook GitHub Bot
Browse files

missing keys in _convert_to_d2

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/462

Fix errors in `_convert_to_d2`. Sometimes the keys are missing, we don't need remove them.

{F860805441}

Reviewed By: newstzpz

Differential Revision: D42929485

fbshipit-source-id: 8584879df5a07cbe5a864b4f170eef3d5f34dd6c
parent 6940fa9c
......@@ -93,10 +93,11 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
prefix = "model" # based on DefaultTask.model.
old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
for key in old_keys:
lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
_STATE_DICT_KEY
][f"{prefix}.{key}"]
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
if f"{prefix}.{key}" in lightning_checkpoint[_STATE_DICT_KEY]:
lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
_STATE_DICT_KEY
][f"{prefix}.{key}"]
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
for old, new in zip(
[_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
......@@ -112,11 +113,15 @@ def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
lightning_checkpoint[new] = [lightning_checkpoint[old]]
del lightning_checkpoint[old]
del lightning_checkpoint["epoch"]
del lightning_checkpoint["pytorch-lightning_version"]
del lightning_checkpoint["callbacks"]
del lightning_checkpoint["hparams_name"]
del lightning_checkpoint["hyper_parameters"]
for key in [
"epoch",
"pytorch-lightning_versio",
"callbacks",
"hparams_name",
"hyper_parameters",
]:
if key in lightning_checkpoint:
del lightning_checkpoint[key]
class ModelTag(str, Enum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment