Unverified Commit 86d148f9 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Optimize calls to .cpu() during checkpointing (#363)



* Optimize calls to .cpu() during checkpointing
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for ONNX
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a0e1cf99
...@@ -87,13 +87,7 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]: ...@@ -87,13 +87,7 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]:
def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]: def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 buffer.""" """Returns global fp8 buffer."""
buffer = {} return _global_fp8_buffer
# Map all tensors to CPU.
for k, v in _global_fp8_buffer.items():
buffer[k] = [tensor.cpu() for tensor in v]
return buffer
def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None: def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Base modules and utilities for TransformerEngine PyTorch API""" """Base modules and utilities for TransformerEngine PyTorch API"""
import io
import os import os
import pickle import pickle
import warnings import warnings
...@@ -11,12 +12,12 @@ from typing import Generator, Union, Optional, Tuple, Dict, Any, List ...@@ -11,12 +12,12 @@ from typing import Generator, Union, Optional, Tuple, Dict, Any, List
from functools import partial from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from ..export import is_in_onnx_export_mode
from ..fp8 import ( from ..fp8 import (
is_fp8_enabled, is_fp8_enabled,
is_fp8_calibration, is_fp8_calibration,
...@@ -349,12 +350,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -349,12 +350,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if fp8_checkpoint: if fp8_checkpoint:
state = {} state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale.cpu() state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv.cpu() state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history.cpu() state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale.cpu() state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv.cpu() state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history.cpu() state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = get_global_fp8_buffer() state["global_fp8_buffer"] = get_global_fp8_buffer()
# Store other pickelable values. # Store other pickelable values.
...@@ -364,10 +365,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -364,10 +365,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
state_serialized = pickle.dumps(state) if is_in_onnx_export_mode():
state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8)) state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
else:
state_serialized = io.BytesIO()
torch.save(state, state_serialized)
return state_tensor return state_serialized
def set_extra_state(self, state: torch.Tensor) -> None: def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state.""" """Load previous state."""
...@@ -409,6 +413,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -409,6 +413,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if isinstance(state, torch.Tensor): if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().cpu().numpy().tobytes()) state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
state = torch.load(state, map_location='cuda')
if state is None: if state is None:
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment