Unverified Commit 86d148f9 authored by Kirthi Shankar Sivamani, committed by GitHub

Optimize calls to .cpu() during checkpointing (#363)



* Optimize calls to .cpu() during checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for ONNX
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a0e1cf99
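The gist of the change: `get_extra_state` previously forced every FP8 scaling tensor through `.cpu()`, a blocking device-to-host copy per tensor on every checkpoint. The commit defers those transfers to `torch.save`, keeping a flat-tensor path only for ONNX export. A minimal sketch of the two paths, with a hypothetical helper name (`serialize_extra_state` is not the module's real API):

```python
import io
import pickle

import torch


def serialize_extra_state(state: dict, onnx_export: bool = False):
    """Illustrative stand-in for the commit's get_extra_state logic."""
    if onnx_export:
        # ONNX export needs a flat tensor, not a file-like object. A bytearray
        # keeps the buffer writable so torch.frombuffer does not warn.
        return torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)
    # Checkpoint path: torch.save serializes GPU tensors directly, so the
    # eager per-tensor .cpu() copies of the old code are no longer needed.
    buffer = io.BytesIO()
    torch.save(state, buffer)
    return buffer
```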
@@ -87,13 +87,7 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]:
 
 def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
     """Returns global fp8 buffer."""
-    buffer = {}
-
-    # Map all tensors to CPU.
-    for k, v in _global_fp8_buffer.items():
-        buffer[k] = [tensor.cpu() for tensor in v]
-
-    return buffer
+    return _global_fp8_buffer
 
 
 def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 
 """Base modules and utilities for TransformerEngine PyTorch API"""
+import io
 import os
 import pickle
 import warnings
@@ -11,12 +12,12 @@ from typing import Generator, Union, Optional, Tuple, Dict, Any, List
 from functools import partial
 from contextlib import contextmanager
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 import transformer_engine_extensions as tex
+from ..export import is_in_onnx_export_mode
 from ..fp8 import (
     is_fp8_enabled,
     is_fp8_calibration,
@@ -349,12 +350,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
 
         if fp8_checkpoint:
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale.cpu()
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv.cpu()
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history.cpu()
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale.cpu()
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv.cpu()
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history.cpu()
+            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
+            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
+            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
+            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
+            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
+            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
             state["global_fp8_buffer"] = get_global_fp8_buffer()
 
         # Store other pickelable values.
@@ -364,10 +365,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 extra[k] = v
         state["extra_fp8_variables"] = extra
 
-        state_serialized = pickle.dumps(state)
-        state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
+        if is_in_onnx_export_mode():
+            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
+        else:
+            state_serialized = io.BytesIO()
+            torch.save(state, state_serialized)
 
-        return state_tensor
+        return state_serialized
 
     def set_extra_state(self, state: torch.Tensor) -> None:
         """Load previous state."""
@@ -409,8 +413,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
 
         if isinstance(state, torch.Tensor):
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
-        if state is None:
-            return
+        elif isinstance(state, io.BytesIO):
+            state.seek(0)
+            state = torch.load(state, map_location='cuda')
+
+        if state is None:
+            return
 
         # Restore global FP8 buffer states.
         set_global_fp8_buffer(state["global_fp8_buffer"])
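On the load side, `set_extra_state` stays backward compatible by accepting both the legacy flat-tensor format and the new `io.BytesIO` stream. A matching sketch, reusing the hypothetical `serialize_extra_state` from above for a CPU-only round-trip check (the real code passes `map_location='cuda'` to restore onto GPU):

```python
import io
import pickle

import torch


def deserialize_extra_state(state):
    """Illustrative counterpart to serialize_extra_state."""
    if isinstance(state, torch.Tensor):
        # Legacy checkpoints: pickled bytes stored in a flat uint8 tensor.
        return pickle.loads(state.detach().cpu().numpy().tobytes())
    if isinstance(state, io.BytesIO):
        # New checkpoints: a torch.save stream; rewind before reading.
        state.seek(0)
        return torch.load(state)
    return None


# Round-trip sanity check on CPU.
state = {"scale_fwd": torch.ones(4), "extra_fp8_variables": {"step": 3}}
restored = deserialize_extra_state(serialize_extra_state(state))
assert torch.equal(restored["scale_fwd"], state["scale_fwd"])
```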