Unverified commit 26ecb2f1, authored by Peter St. John and committed by GitHub
Browse files

Don't serialize a None tensor if not using fp8 (#1749)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0c5e3a52
......@@ -592,7 +592,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> torch.Tensor:
def get_extra_state(self) -> Optional[torch.Tensor]:
"""Save before checkpointing."""
# This implementation is working around a few issues:
......@@ -626,25 +626,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Store FP8 state if needed
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if fp8_checkpoint:
# Copy tensors to CPU and store
state = {}
state["recipe"] = self.fp8_meta["recipe"]
if state["recipe"].delayed():
state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
# Store other picklable values
extra = {}
for k, v in self.fp8_meta.items():
if k != "buffer_index_and_autocast_key" and isinstance(
v, (bool, int, float, str, tuple, list)
):
extra[k] = v
state["extra_fp8_variables"] = extra
if not fp8_checkpoint:
return None
# Copy tensors to CPU and store
state = {}
state["recipe"] = self.fp8_meta["recipe"]
if state["recipe"].delayed():
state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
# Store other picklable values
extra = {}
for k, v in self.fp8_meta.items():
if k != "buffer_index_and_autocast_key" and isinstance(
v, (bool, int, float, str, tuple, list)
):
extra[k] = v
state["extra_fp8_variables"] = extra
# Serialize state into byte tensor
torch.cuda.synchronize()
......@@ -652,7 +653,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized
def set_extra_state(self, state: torch.Tensor) -> None:
def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load previous state."""
if state is None:
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment