"vscode:/vscode.git/clone" did not exist on "ea55f1f52c489535f0d3b583c81529762c9cb5ea"
Unverified commit f351191b authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Fix FP8 checkpointing for non forward execution cases (#323)



Bug fix for checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 04822f40
@@ -267,6 +267,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8 = False
         self.fp8_calibration = False
         self.fp8_meta = {}
+        self.fp8_meta["fp8_checkpoint"] = False
         self.fp8_meta["fp8_group"] = None
         self.fp8_meta["recipe"] = get_default_fp8_recipe()
         self.fp8_meta_tensors_initialized = False
@@ -341,7 +342,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
         state = None
-        if self.fp8 or self.fp8_calibration:
+        # Maintain backward compatibility.
+        fp8_checkpoint = "fp8_checkpoint" in self.fp8_meta and self.fp8_meta["fp8_checkpoint"]
+        fp8_checkpoint = fp8_checkpoint or self.fp8 or self.fp8_calibration
+        if fp8_checkpoint:
             state = {}
             state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
             state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
@@ -513,6 +519,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Initialize fp8 related metadata and tensors during fprop.""" """Initialize fp8 related metadata and tensors during fprop."""
self.fp8 = is_fp8_enabled() self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration() self.fp8_calibration = is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8 or self.fp8_calibration: if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything. # FP8 init has already been run and recipe is the same, don't do anything.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment