Unverified Commit 0efc7daf authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Fix backward compatibility for checkpoint loading (#1868)



Fix for loading old ckpt formats
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent aedd7e10
......@@ -820,6 +820,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
# Maintain backwards compatibility with older checkpoints.
if state is None:
return
# Load state
if isinstance(state, torch.Tensor):
# No FP8 is indicated by an empty tensor we don't need to unpickle.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment