Unverified commit 8c004241, authored by Tim Moon, committed by GitHub

[PyTorch] Store module extra state in tensor (#1335)



Store module extra state in tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 71ada55f
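For context, this commit builds on PyTorch's "extra state" hooks: a module may override get_extra_state/set_extra_state, and whatever the getter returns is carried through state_dict() under an "_extra_state" key and handed back by load_state_dict(). A minimal sketch of that mechanism, with a made-up Toy module that is not part of TransformerEngine:

import torch

class Toy(torch.nn.Module):
    """Illustrative module with non-parameter state to checkpoint."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2))
        self.counter = 0  # non-parameter state we want checkpointed

    def get_extra_state(self):
        # Called by state_dict(); stored under the "_extra_state" key
        return {"counter": self.counter}

    def set_extra_state(self, state):
        # Called by load_state_dict() with whatever get_extra_state returned
        self.counter = state["counter"]

model = Toy()
model.counter = 42
sd = model.state_dict()       # contains "weight" and "_extra_state"
restored = Toy()
restored.load_state_dict(sd)  # routes the dict back through set_extra_state
assert restored.counter == 42

Returning a plain Python object like this dict is exactly what the diff below moves away from, for the reasons spelled out in its comments.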
@@ -588,20 +588,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
-        state = None
-        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
+        # This implementation is working around a few issues:
+        #
+        # (1) PyTorch's "extra state" infrastructure might be able to
+        #     support any picklable type, but they make no guarantees.
+        #     We have experienced problems (e.g. in ONNX export) with
+        #     non-tensor extra state.
+        # (2) PyTorch's checkpointing infrastructure does not remap
+        #     devices for "extra state" like it does for "state dict".
+        #     Thus, we want to avoid putting extra state on the GPU
+        #     since it may be loaded on the wrong device.
+        # (3) The extra state consists of many small tensors. If we
+        #     want to copy them all to CPU, then we need to avoid the
+        #     overhead of many GPU-CPU memory transfers.
+        #
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+        def to_cpu(src: torch.Tensor) -> torch.Tensor:
+            """Helper function to make CPU copy of tensor
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+            """
+            dst = torch.empty_like(src, device="cpu")
+            dst.copy_(src, non_blocking=True)
+            return dst
+
+        # Store FP8 state if needed
+        state = None
+        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
         if fp8_checkpoint:
+            # Copy tensors to CPU and store
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-
-            # Store other picklable values.
+            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+            state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
+            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+            state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
+
+            # Store other picklable values
             extra = {}
             for k, v in self.fp8_meta.items():
                 if k != "buffer_index_and_autocast_key" and isinstance(
@@ -610,12 +640,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     extra[k] = v
             state["extra_fp8_variables"] = extra

-        if is_in_onnx_export_mode():
-            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-        else:
-            state_serialized = io.BytesIO()
-            torch.save(state, state_serialized)
+        # Serialize state into byte tensor
+        torch.cuda.synchronize()
+        state_serialized = bytearray(pickle.dumps(state))
+        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)

         return state_serialized
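The serialization above round-trips cleanly: the pickled bytes live in a flat uint8 tensor, which PyTorch checkpointing handles like any other tensor. A rough standalone illustration of the same pattern (the state dict here is invented; the real code pickles the FP8 state built above):

import pickle
import torch

state = {"scale_fwd": torch.ones(4), "steps": 3}  # stand-in for the FP8 state

# bytearray gives torch.frombuffer the writable buffer it expects
buf = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(buf, dtype=torch.uint8)

# Recover the original objects from the byte tensor
restored = pickle.loads(state_serialized.numpy().tobytes())
assert torch.equal(restored["scale_fwd"], state["scale_fwd"])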
     def set_extra_state(self, state: torch.Tensor) -> None:
@@ -623,9 +651,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return

+        # Load state
         if isinstance(state, torch.Tensor):
+            # Default format: byte tensor with pickled data
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
+            # Deprecated format with io.BytesIO
             state.seek(0)
             state = torch.load(state, map_location="cuda")
         else:
@@ -634,20 +665,32 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return

-        # Load extra items.
+        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

-        # Initialize before loading.
+        # Initialize before loading
         self.init_fp8_meta_tensors()
-        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
-        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
-        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
-        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+
+        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+            """Helper function to copy tensor from CPU
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+            """
+            dst.copy_(src, non_blocking=True)
+
+        # Load tensors
+        copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
+        copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
+        copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
+        copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
+        copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
+        copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
+        torch.cuda.synchronize()
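The load path mirrors the save path: all six copies are queued with non_blocking=True and the host waits once at the end, instead of paying one synchronization per tensor. A minimal sketch of that pattern in isolation (invented tensors; assumes a CUDA device is available):

import torch

cpu_srcs = [torch.randn(8) for _ in range(6)]
gpu_dsts = [torch.empty(8, device="cuda") for _ in range(6)]

# Queue all transfers without blocking the host
for src, dst in zip(cpu_srcs, gpu_dsts):
    dst.copy_(src, non_blocking=True)

# A single synchronization covers every queued transfer
torch.cuda.synchronize()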
     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
@@ -514,7 +514,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         #
         # (1) PyTorch's "extra state" infrastructure might be able to
         #     support any picklable type, but they make no guarantees.
-        #     It seems that ONNX export experiences issues with
+        #     We have experienced problems (e.g. in ONNX export) with
         #     non-tensor extra state.
         # (2) PyTorch's checkpointing infrastructure does not remap
         #     devices for "extra state" like it does for "state dict".