Unverified Commit c126396b authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Address steady memory increase and bloated checkpoints (#63)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f63b27e8
......@@ -5,7 +5,7 @@
"""FP8 utilities for TransformerEngine"""
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_extensions as tex
......@@ -65,22 +65,6 @@ def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
_global_fp8_buffer = buffer
def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]:
    """Return the module-level FP8 activation-recompute buffer.

    NOTE(review): the annotated return type disagrees with the paired setter,
    ``set_global_fp8_recompute_buffer``, which accepts
    ``List[Deque[List[torch.Tensor]]]`` and assigns it to the same global
    (``_fp8_tensors_recompute_buffer``). One of the two annotations is wrong —
    confirm against the buffer's producer before relying on either.
    """
    return _fp8_tensors_recompute_buffer
def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None:
    """Sets global fp8 recompute buffer."""
    global _fp8_tensors_recompute_buffer
    # Restore every saved tensor group onto the GPU (tensors may have been
    # offloaded, e.g. when the buffer was serialized), rebuilding each deque
    # in place before installing the buffer as the module-level global.
    for position, saved_queue in enumerate(buffer):
        restored = deque()
        for tensor_group in saved_queue:
            restored.append([tensor.cuda() for tensor in tensor_group])
        buffer[position] = restored
    _fp8_tensors_recompute_buffer = buffer
def setup_amax_forward_global_reduce_func(f: Callable) -> None:
"""Sets up the function to call during autocast exit."""
global _amax_forward_global_reduce_func
......
......@@ -35,8 +35,6 @@ from .fp8 import (
amax_and_scale_update,
get_global_fp8_buffer,
set_global_fp8_buffer,
get_global_fp8_recompute_buffer,
set_global_fp8_recompute_buffer,
set_amax_buffer_key_deletion,
delete_key_from_amax_buffer,
copy_forward_fp8_meta_tensors_for_recompute,
......@@ -209,7 +207,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = get_global_fp8_buffer()
state["global_fp8_recompute_buffer"] = get_global_fp8_recompute_buffer()
# Store other picklable values.
extra = {}
......@@ -269,11 +266,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"])
set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
# Initialize before loading.
self.init_fp8_meta_tensors()
......@@ -452,6 +449,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and self.training
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment