"...src/deepstream-yaml/deepstream_config_yaml.cpp" did not exist on "77e9fc8cc58ea12322adddc53ccc84b2f389bb4d"
Unverified commit 85928d08, authored by Kirthi Shankar Sivamani and committed by GitHub

Store FP8 checkpointing data in CPU (#351)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c8175d9e
@@ -87,7 +87,13 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]:
 def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
     """Returns global fp8 buffer."""
-    return _global_fp8_buffer
+    buffer = {}
+    # Map all tensors to CPU.
+    for k, v in _global_fp8_buffer.items():
+        buffer[k] = [tensor.cpu() for tensor in v]
+    return buffer


 def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
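The hunk above copies every tensor in the global FP8 amax buffer to host memory before returning it, so a checkpoint built from the buffer no longer holds references to CUDA storage. A minimal self-contained sketch of the same pattern follows; the buffer contents and the helper name map_buffer_to_cpu are made up for illustration, and only the CPU-mapping idiom mirrors the diff:

import torch
from typing import Dict, List

def map_buffer_to_cpu(
    buffer: Dict[str, List[torch.Tensor]]
) -> Dict[str, List[torch.Tensor]]:
    """Return a copy of the buffer with every tensor moved to CPU.

    tensor.cpu() returns the tensor itself when it already lives on
    CPU, so the helper is also safe in CPU-only runs.
    """
    return {key: [tensor.cpu() for tensor in tensors]
            for key, tensors in buffer.items()}

# Hypothetical buffer: a few amax tensors per key, as the global buffer holds.
device = "cuda" if torch.cuda.is_available() else "cpu"
buf = {"fwd_amax": [torch.rand(1, device=device) for _ in range(2)]}
cpu_buf = map_buffer_to_cpu(buf)
assert all(t.device.type == "cpu" for ts in cpu_buf.values() for t in ts)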
@@ -349,12 +349,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if fp8_checkpoint:
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
+            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale.cpu()
+            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv.cpu()
+            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history.cpu()
+            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale.cpu()
+            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv.cpu()
+            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history.cpu()
             state["global_fp8_buffer"] = get_global_fp8_buffer()
             # Store other picklable values.
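Per module, each FP8 scaling tensor is likewise copied to CPU before it enters the extra state, which keeps torch.save from serializing CUDA storage and pinning the checkpoint to a particular GPU. A hedged sketch of the resulting save/load round trip follows; the tensor names are hypothetical stand-ins, not the actual fp8_meta layout:

import io
import torch

# Hypothetical stand-ins for the fp8_meta["scaling_fwd"] tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
scale_fwd = torch.ones(4, device=device)
amax_history_fwd = torch.zeros(16, 4, device=device)

state = {
    "scale_fwd": scale_fwd.cpu(),                # host copies only
    "amax_history_fwd": amax_history_fwd.cpu(),
}

# CPU tensors serialize without any CUDA dependency; on load, the
# consumer chooses the target device via map_location instead.
buf = io.BytesIO()
torch.save(state, buf)
buf.seek(0)
restored = torch.load(buf, map_location="cpu")
assert restored["scale_fwd"].device.type == "cpu"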