Unverified Commit 8c004241 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Store module extra state in tensor (#1335)



Store module extra state in tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 71ada55f
...@@ -588,20 +588,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -588,20 +588,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def get_extra_state(self) -> torch.Tensor: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration # This implementation is working around a few issues:
#
# (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees.
# We have experienced problems (e.g. in ONNX export) with
# non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict".
# Thus, we want to avoid putting extra state on the GPU
# since it may be loaded on the wrong device.
# (3) The extra state consists of many small tensors. If we
# want to copy them all to CPU, then we need to avoid the
# overhead of many GPU-CPU memory transfers.
#
# See: https://github.com/NVIDIA/TransformerEngine/pull/351
# See: https://github.com/NVIDIA/TransformerEngine/pull/363
def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.
"""
dst = torch.empty_like(src, device="cpu")
dst.copy_(src, non_blocking=True)
return dst
# Store FP8 state if needed
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if fp8_checkpoint: if fp8_checkpoint:
# Copy tensors to CPU and store
state = {} state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
# Store other picklable values. # Store other picklable values
extra = {} extra = {}
for k, v in self.fp8_meta.items(): for k, v in self.fp8_meta.items():
if k != "buffer_index_and_autocast_key" and isinstance( if k != "buffer_index_and_autocast_key" and isinstance(
...@@ -610,12 +640,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -610,12 +640,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
if is_in_onnx_export_mode(): # Serialize state into byte tensor
state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8) torch.cuda.synchronize()
else: state_serialized = bytearray(pickle.dumps(state))
state_serialized = io.BytesIO() state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
torch.save(state, state_serialized)
return state_serialized return state_serialized
def set_extra_state(self, state: torch.Tensor) -> None: def set_extra_state(self, state: torch.Tensor) -> None:
...@@ -623,9 +651,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -623,9 +651,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# Load state
if isinstance(state, torch.Tensor): if isinstance(state, torch.Tensor):
# Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes()) state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO): elif isinstance(state, io.BytesIO):
# Deprecated format with io.BytesIO
state.seek(0) state.seek(0)
state = torch.load(state, map_location="cuda") state = torch.load(state, map_location="cuda")
else: else:
...@@ -634,20 +665,32 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -634,20 +665,32 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# Load extra items. # Load extra items
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
# Initialize before loading. # Initialize before loading
self.init_fp8_meta_tensors() self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) """Helper function to copy tensor from CPU
self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"]) Memory transfer is asynchronous w.r.t. host, so GPU should
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"]) be synchronized before using result.
"""
dst.copy_(src, non_blocking=True)
# Load tensors
copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
torch.cuda.synchronize()
def set_activation_dtype(self, inp: torch.Tensor) -> None: def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP.""" """Get activation data type for AMP."""
......
...@@ -514,7 +514,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -514,7 +514,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# #
# (1) PyTorch's "extra state" infrastructure might be able to # (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees. # support any picklable type, but they make no guarantees.
# It seems that ONNX export experiences issues with # We have experienced problems (e.g. in ONNX export) with
# non-tensor extra state. # non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap # (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict". # devices for "extra state" like it does for "state dict".
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment