Unverified Commit 26ecb2f1 authored by Peter St. John, committed by GitHub

Don't serialize a None tensor if not using fp8 (#1749)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0c5e3a52
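
For context, PyTorch's `get_extra_state`/`set_extra_state` hooks let a module store arbitrary extra data in its `state_dict`, and a plain `None` round-trips through them cleanly. Below is a minimal sketch of the behavior this commit relies on; the module and attribute names are illustrative, not from Transformer Engine:

```python
import torch

class FP8AwareModule(torch.nn.Module):
    """Illustrative module using PyTorch's extra-state hooks."""

    def __init__(self):
        super().__init__()
        self.fp8 = False  # stand-in for the FP8 flags the real module checks

    def get_extra_state(self):
        # After this commit: return None directly when there is no FP8
        # state, instead of serializing a None placeholder.
        if not self.fp8:
            return None
        return torch.zeros(4)  # stand-in for the real serialized FP8 state

    def set_extra_state(self, state):
        if state is None:
            return  # nothing was saved, so nothing to restore
        # ... deserialize and restore FP8 state here ...

module = FP8AwareModule()
state_dict = module.state_dict()
# PyTorch stores the hook's return value under "<prefix>_extra_state".
print(state_dict["_extra_state"])  # None when FP8 is disabled
```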
@@ -592,7 +592,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         reset("scaling_fwd")
         reset("scaling_bwd")
 
-    def get_extra_state(self) -> torch.Tensor:
+    def get_extra_state(self) -> Optional[torch.Tensor]:
         """Save before checkpointing."""
 
         # This implementation is working around a few issues:
@@ -626,7 +626,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Store FP8 state if needed
         state = None
         fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-        if fp8_checkpoint:
+        if not fp8_checkpoint:
+            return None
 
         # Copy tensors to CPU and store
         state = {}
@@ -652,7 +653,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
         return state_serialized
 
-    def set_extra_state(self, state: torch.Tensor) -> None:
+    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
         """Load previous state."""
         if state is None:
             return
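
Concretely, the old path still serialized the `None` placeholder into a one-off `uint8` tensor, so every non-FP8 checkpoint carried a small opaque blob. The sketch below assumes a `pickle.dumps` step feeding the `torch.frombuffer` call visible in the last hunk; that pairing is an assumption about the surrounding code, not shown in this diff:

```python
import pickle
import torch

# Before the fix: with FP8 disabled, state stayed None but was still
# pickled and wrapped in a uint8 tensor on the way into the checkpoint.
blob = torch.frombuffer(bytearray(pickle.dumps(None)), dtype=torch.uint8)
print(blob.numel())  # a few bytes that encode nothing but `None`

# After the fix: get_extra_state() returns None directly, so no blob is
# written; set_extra_state() already short-circuits on None when loading.
```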