Unverified commit c126396b authored by Kirthi Shankar Sivamani, committed by GitHub

Address steady memory increase and bloated checkpoints (#63)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f63b27e8
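
This commit stops serializing the global FP8 activation-recompute buffer with every module's state and only stashes recompute tensors while training. As a rough illustration of the checkpoint-bloat half of the problem, here is a minimal, hypothetical sketch (invented names, not TransformerEngine code) of what happens when a shared tensor buffer is embedded in each module's __getstate__:

import io
import pickle

import torch

# Hypothetical stand-in for a process-global buffer of tensors.
_global_buffer = [torch.randn(1024, 1024) for _ in range(8)]  # ~32 MiB total

class Layer:
    def __getstate__(self):
        state = dict(self.__dict__)
        # Anti-pattern: every pickled layer drags the whole shared buffer
        # along, so a model with N layers stores it N times.
        state["global_buffer"] = _global_buffer
        return state

stream = io.BytesIO()
pickle.dump(Layer(), stream)
print(f"one empty layer pickles to {stream.getbuffer().nbytes / 2**20:.1f} MiB")

With the buffer excluded from per-module state, checkpoint size drops back to the tensors each module actually owns.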
@@ -5,7 +5,7 @@
 """FP8 utilities for TransformerEngine"""
 from contextlib import contextmanager
 from collections import deque
-from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque
+from typing import Callable, List, Optional, Dict, Any, Tuple, Union
 import torch
 import transformer_engine_extensions as tex
@@ -65,22 +65,6 @@ def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
     _global_fp8_buffer = buffer
 
 
-def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]:
-    """Returns global fp8 recompute buffer."""
-    return _fp8_tensors_recompute_buffer
-
-
-def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None:
-    """Sets global fp8 recompute buffer."""
-    global _fp8_tensors_recompute_buffer
-
-    # Map all tensors back to GPU.
-    for index, deck in enumerate(buffer):
-        buffer[index] = deque([[t.cuda() for t in tensors] for tensors in deck])
-
-    _fp8_tensors_recompute_buffer = buffer
-
-
 def setup_amax_forward_global_reduce_func(f: Callable) -> None:
     """Sets up the function to call during autocast exit."""
     global _amax_forward_global_reduce_func
...
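
For context on what was deleted above: the recompute buffer holds per-iteration copies of FP8 metadata tensors that are pushed during the first forward pass and popped during the recompute pass, so nothing in it needs to outlive an iteration, let alone a checkpoint. A minimal sketch of that FIFO lifecycle, with hypothetical names (only _fp8_tensors_recompute_buffer comes from the diff):

from collections import deque
from typing import Deque, List

import torch

# Hypothetical sketch: entries live only between the two forward passes of a
# single iteration, so they never need to be serialized or moved back to GPU.
_buffer: Deque[List[torch.Tensor]] = deque()

def stash_fp8_meta(tensors: List[torch.Tensor]) -> None:
    # First forward pass: keep copies of the FP8 meta tensors.
    _buffer.append([t.clone() for t in tensors])

def restore_fp8_meta() -> List[torch.Tensor]:
    # Recompute phase: consume the oldest stashed entry.
    return _buffer.popleft()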
@@ -35,8 +35,6 @@ from .fp8 import (
     amax_and_scale_update,
     get_global_fp8_buffer,
     set_global_fp8_buffer,
-    get_global_fp8_recompute_buffer,
-    set_global_fp8_recompute_buffer,
     set_amax_buffer_key_deletion,
     delete_key_from_amax_buffer,
     copy_forward_fp8_meta_tensors_for_recompute,
@@ -209,7 +207,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
         state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
         state["global_fp8_buffer"] = get_global_fp8_buffer()
-        state["global_fp8_recompute_buffer"] = get_global_fp8_recompute_buffer()
 
         # Store other picklable values.
         extra = {}
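
The hunk above removes the recompute buffer from the pickled state entirely. The general pattern, as a hypothetical sketch (not the library's code): transient, iteration-scoped data is dropped in __getstate__ and recreated empty in __setstate__.

class CheckpointFriendly:
    def __init__(self):
        self.weight = [1.0] * 4   # durable state worth checkpointing
        self._stash = []          # transient; rebuilt every iteration

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_stash", None)  # never serialize the transient stash
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._stash = []           # recreate empty on load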
@@ -269,11 +266,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Restore global FP8 buffer states.
         set_global_fp8_buffer(state["global_fp8_buffer"])
-        set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
 
         # Load extra items.
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
+        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
+            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
 
         # Initialize before loading.
         self.init_fp8_meta_tensors()
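
The added deletion guards against checkpoints written before this change: they still carry the recompute-buffer position key, and without the scrub it would be resurrected into fp8_meta on load. The same idea in isolation (hypothetical helper; the key name comes from the diff above):

def scrub_legacy_keys(fp8_meta: dict) -> None:
    # Drop bookkeeping keys that pre-fix checkpoints may still carry.
    fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)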
@@ -452,6 +449,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Activation recomputation is used and this is the first forward phase.
         if (
             self.fp8
+            and self.training
             and is_fp8_activation_recompute_enabled()
             and not in_fp8_activation_recompute_phase()
         ):
...
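
The new "and self.training" term is what stops the steady memory increase: meta tensors are stashed for recompute only when a backward pass will later pop them. In eval mode there is no recompute phase, so without the guard every inference forward would append to the buffer and nothing would ever drain it. A hypothetical condensation of the fixed predicate:

def should_stash_for_recompute(module, fp8_enabled, recompute_enabled, in_recompute_phase):
    # Hypothetical sketch of the condition in the diff above, with the
    # module.training guard added by this commit.
    return (
        fp8_enabled
        and module.training
        and recompute_enabled
        and not in_recompute_phase
    )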