Unverified Commit c149c145 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix bugs for full activation recompute in FP8 (#24)

* Fix bugs for full activation recompute in FP8
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Ensure identical numerics in recomputation for pipeline parallelism
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* expose checkpoint API and add docs
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* complete checkpointing docs
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9f7b0255
@@ -27,3 +27,5 @@ Functions
 ---------

 .. autofunction:: transformer_engine.pytorch.fp8_autocast
+
+.. autofunction:: transformer_engine.pytorch.checkpoint
@@ -9,3 +9,4 @@ from .module import LayerNormMLP
 from .module import LayerNorm
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
+from .distributed import checkpoint
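With the re-export above, the checkpoint function becomes reachable from the top-level PyTorch namespace alongside fp8_autocast. A minimal import check (not part of the diff, assuming transformer_engine is installed):

# Both names are exported by transformer_engine/pytorch/__init__.py after this change.
import transformer_engine.pytorch as te

assert callable(te.fp8_autocast)  # existing FP8 autocast context manager
assert callable(te.checkpoint)    # newly exposed activation-recompute checkpoint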
@@ -3,13 +3,16 @@
 # See LICENSE for license information.

 """Methods needed for distributed training (DP/TP)."""
-from typing import Union, Optional, Callable, Tuple
+from contextlib import contextmanager
+from typing import Any, Dict, Union, Optional, Callable, Tuple

 import torch
 from torch.cuda import _lazy_call
 from torch.utils.checkpoint import detach_variable

 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
+from .fp8 import is_fp8_enabled

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
     "tensor_model_parallel": False,
@@ -17,6 +20,9 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
     "partition_stride": 1,
 }

+_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
+_FP8_ACTIVATION_RECOMPUTE_PHASE = False
+

 def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
     """Sets the random number generator state of the current GPU.
@@ -125,6 +131,38 @@ def gather_split_1d_tensor(
     return gathered


+@contextmanager
+def activation_recompute_forward(
+    activation_recompute: bool = False,
+    recompute_phase: bool = False,
+) -> None:
+    """Context manager used to control the forward runtime behavior when executed
+    under the `CheckpointFunction` function. For running FP8, the forward pass will
+    run without storing intermediate activations. Instead, the forward pass saves
+    the inputs tuple and the calling function. In the backwards pass, these are
+    retrieved, and the forward pass is computed again while tracking the intermediate
+    activations, followed by calculation of gradients using these values.
+    """
+    global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
+    try:
+        _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and is_fp8_enabled()
+        _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
+        yield
+    finally:
+        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
+        _FP8_ACTIVATION_RECOMPUTE_PHASE = False
+
+
+def is_fp8_activation_recompute_enabled() -> bool:
+    """Return global boolean"""
+    return _FP8_ACTIVATION_RECOMPUTE_ENABLED
+
+
+def in_fp8_activation_recompute_phase() -> bool:
+    """Return global boolean"""
+    return _FP8_ACTIVATION_RECOMPUTE_PHASE
+
+
 class CheckpointFunction(torch.autograd.Function):
     """This function is adapted from torch.utils.checkpoint with
     two main changes:
@@ -140,6 +178,7 @@ class CheckpointFunction(torch.autograd.Function):
         distribute_saved_activations: bool,
         get_cuda_rng_tracker: Callable,
         tp_group: dist_group_type,
+        kwargs: Dict[str, Any],
         *args: Tuple[torch.Tensor, ...],
     ) -> Tuple[torch.Tensor, ...]:
         ctx.run_function = run_function
@@ -151,7 +190,10 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

         with torch.no_grad():
-            outputs = run_function(*args)
+            with activation_recompute_forward(
+                activation_recompute=True, recompute_phase=False
+            ):
+                outputs = run_function(*args, **kwargs)

         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
@@ -168,6 +210,7 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.save_for_backward(*args)

         ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
         ctx.tp_group = tp_group
+        ctx.kwargs = kwargs

         return outputs
@@ -204,7 +247,10 @@ class CheckpointFunction(torch.autograd.Function):
         # Compute the forward pass.
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad():
-            outputs = ctx.run_function(*detached_inputs)
+            with activation_recompute_forward(
+                activation_recompute=True, recompute_phase=True
+            ):
+                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
@@ -218,7 +264,7 @@ class CheckpointFunction(torch.autograd.Function):
             inp.grad if isinstance(inp, torch.Tensor) else inp
            for inp in detached_inputs
         )
-        return (None, None, None, None) + grads
+        return (None, None, None, None, None) + grads


 def checkpoint(
@@ -227,11 +273,54 @@ def checkpoint(
     get_cuda_rng_tracker: Callable,
     tp_group: dist_group_type,
     *args: Tuple[torch.Tensor, ...],
+    **kwargs: Dict[str, Any],
 ) -> Tuple[torch.Tensor, ...]:
-    """Checkpoint a model or part of the model.
-    This has been directly copied from torch.utils.checkpoint."""
+    """
+    Checkpoint a part of the model by trading compute for memory. This function is based on
+    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.
+
+    .. warning::
+
+        It is the user's responsibility to ensure identical behavior when calling
+        :attr:`function` from the forward and backward pass. If different output is
+        produced (e.g. due to global state), then the checkpointed version won't
+        be numerically equivalent.
+
+    .. warning::
+
+        The tuple :attr:`args` must contain only tensors (or :attr:`None`) in order to comply with
+        PyTorch's :attr:`save_for_backward` method. :attr:`function` must be callable to produce
+        valid outputs with the inputs :attr:`args` and :attr:`kwargs`.
+
+    Parameters
+    ----------
+    function: Callable
+        pytorch module (or function) that runs the forward and backward passes
+        using the supplied :attr:`args` and :attr:`kwargs`.
+    distribute_saved_activations: bool
+        if set to `True`, the first tensor argument is distributed across the
+        specified tensor parallel group (`tp_group`) before saving it for the
+        backward pass.
+    get_cuda_rng_tracker: `Callable`
+        python callable with the functionality to retrieve a state via
+        :attr:`state = get_cuda_rng_tracker().get_states()` and to reset the state via
+        :attr:`get_cuda_rng_tracker().set_states(state)`. This is used to ensure any
+        extra cuda rng state or general global state can be reproduced across the two
+        forward phases: original and recompute.
+    tp_group : ProcessGroup, default = `None`
+        tensor parallel process group.
+    args : tuple
+        tuple of torch tensors for inputs to :attr:`function`.
+    kwargs : dict
+        dictionary of string keys for keyword arguments to :attr:`function`.
+    """
     return CheckpointFunction.apply(
-        function, distribute_saved_activations, get_cuda_rng_tracker, tp_group, *args
+        function,
+        distribute_saved_activations,
+        get_cuda_rng_tracker,
+        tp_group,
+        kwargs,
+        *args,
     )
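A hedged usage sketch of the checkpoint API defined above (not part of the diff). The TransformerLayer sizes and the stub RNG tracker are illustrative assumptions; only the call order function, distribute_saved_activations, get_cuda_rng_tracker, tp_group, *args, **kwargs follows the code:

import torch
import transformer_engine.pytorch as te

class _RNGTracker:
    """Minimal stand-in for a Megatron-style CUDA RNG tracker (assumed helper)."""
    def get_states(self):
        return {}
    def set_states(self, states):
        pass

_tracker = _RNGTracker()

def get_cuda_rng_tracker():
    return _tracker

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
).cuda()
inp = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    out = te.checkpoint(
        layer,                 # module run in both the original and recompute forward
        False,                 # distribute_saved_activations
        get_cuda_rng_tracker,  # callable exposing get_states()/set_states()
        None,                  # tp_group (no tensor parallelism in this sketch)
        inp,                   # *args: tensors only, per the warning above
        attention_mask=None,   # **kwargs forwarded to the module
    )
out.sum().backward()           # triggers the FP8-consistent recompute phase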
@@ -4,6 +4,7 @@
 """FP8 utilities for TransformerEngine"""
 from contextlib import contextmanager
+from collections import deque
 from typing import Callable, List, Optional, Dict, Any, Tuple, Union

 import torch
@@ -20,6 +21,7 @@ _FP8_AUTOCAST_COUNTER = 0
 _FP8_CURRENT_CONTEXT_ID = 0
 _FP8_AUTOCAST_DEPTH = 0
 _global_fp8_buffer = {}
+_fp8_tensors_recompute_buffer = []
 _amax_forward_global_reduce_func = None
 _buffer_delete_key_fwd = None
 _buffer_delete_key_bwd = None
@@ -93,6 +95,59 @@ def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) ->
        fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1


+def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
+    """Copy the scaling factors and amaxes for recompute forward phase
+    to ensure both forward steps are numerically same.
+    """
+    global _fp8_tensors_recompute_buffer
+    buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
+
+    to_copy = (
+        fp8_meta["scaling_fwd"].amax_history.clone(),
+        fp8_meta["scaling_fwd"].scale.clone(),
+        fp8_meta["scaling_fwd"].scale_inv.clone(),
+    )
+
+    if buffer_position_key in fp8_meta:
+        _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
+    else:
+        if len(_fp8_tensors_recompute_buffer) == 0:
+            _fp8_tensors_recompute_buffer = [deque()]
+        else:
+            _fp8_tensors_recompute_buffer.append(deque())
+        _fp8_tensors_recompute_buffer[-1].append(to_copy)
+        fp8_meta[buffer_position_key] = len(_fp8_tensors_recompute_buffer) - 1
+
+
+def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
+    """Switch to the copied scaling factors and amaxes from phase
+    1 forward for identical numerical outputs.
+    """
+    # Store updated amaxes and scales from phase 1 post forward.
+    fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
+    fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
+    fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv
+
+    # Retrieve stashed amaxes and scales from phase 1 pre forward.
+    buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
+    stashed_fp8_meta = _fp8_tensors_recompute_buffer[
+        fp8_meta[buffer_position_key]
+    ].popleft()
+
+    # Replace amaxes and scales with stashed values for phase 2 forward.
+    fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
+    fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
+    fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
+
+
+def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
+    """Restore latest scaling factors and amaxes after recompute forward run."""
+    fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
+    fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
+    fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
+
+
 def copy_amax_from_global_buffer(
     fp8_meta: Dict[str, Any], forward: bool = True
 ) -> None:
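A self-contained toy (not part of the diff) illustrating the stash/swap/restore cycle the three helpers above implement. The SimpleNamespace stand-in for fp8_meta["scaling_fwd"] and the variable names are assumptions made for clarity:

from collections import deque
from types import SimpleNamespace
import torch

scaling_fwd = SimpleNamespace(
    amax_history=torch.zeros(2), scale=torch.ones(2), scale_inv=torch.ones(2)
)
stash = deque()

# Phase 1 (first forward, no_grad): snapshot the pre-forward scaling state.
stash.append((scaling_fwd.amax_history.clone(),
              scaling_fwd.scale.clone(),
              scaling_fwd.scale_inv.clone()))
scaling_fwd.amax_history += 1.0   # the forward pass updates amaxes/scales

# Phase 2 (recompute during backward): keep the updated state aside and
# temporarily swap the phase-1 snapshot back in so numerics match exactly.
updated = (scaling_fwd.amax_history, scaling_fwd.scale, scaling_fwd.scale_inv)
scaling_fwd.amax_history, scaling_fwd.scale, scaling_fwd.scale_inv = stash.popleft()

# ... the recomputed forward runs here with the stashed (phase-1) values ...

# After recompute: restore the post-phase-1 state so training continues normally.
scaling_fwd.amax_history, scaling_fwd.scale, scaling_fwd.scale_inv = updated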
@@ -33,6 +33,9 @@ from .fp8 import (
     set_global_fp8_buffer,
     set_amax_buffer_key_deletion,
     delete_key_from_amax_buffer,
+    copy_forward_fp8_meta_tensors_for_recompute,
+    get_old_fp8_meta_tensors_for_recompute,
+    restore_fp8_meta_tensors,
 )
 from .jit import (
     bias_gelu_fused,
@@ -53,6 +56,8 @@ from .distributed import (
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     gather_along_last_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from .cpp_extensions import (
     fp8_gemm,
@@ -291,6 +296,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):

     def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None:
         """Checks and prep for FWD."""
+        # Activation recomputation is used and this is the second forward phase.
+        if self.fp8 and in_fp8_activation_recompute_phase():
+            get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
+            return
+
         assert inp.is_cuda, "TransformerEngine needs CUDA."

         if self.tp_size > 1:
@@ -306,7 +316,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 amax_and_scale_update(self.fp8_meta, True)
                 set_amax_buffer_key_deletion(self.fp8_meta, forward=True)

-        if self.fp8 and torch.is_grad_enabled() and self.training:
+        if self.fp8 and self.training:
             self.fp8_meta["first_module"] = is_first_fp8_module()

             if self.fp8_meta["first_module"]:
@@ -320,6 +330,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         else:
             self.fp8_meta["update_amax_and_scale_fwd"] = False

+        # Activation recomputation is used and this is the first forward phase.
+        if (
+            self.fp8
+            and is_fp8_activation_recompute_enabled()
+            and not in_fp8_activation_recompute_phase()
+        ):
+            copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
+
     def post_forward(self) -> None:
         """This is needed because there isn't a way for a module to know
         if it's the last FP8 module in the forward autocast. It is useful
@@ -327,7 +345,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         just in case. The autocast exit will pick up the most recent.
         """
-        if self.fp8 and torch.is_grad_enabled() and self.training:
+        if self.fp8 and in_fp8_activation_recompute_phase():
+            restore_fp8_meta_tensors(self.fp8_meta)
+            return
+
+        if self.fp8 and self.training:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             reduce_func = partial(
                 global_amax_reduction,
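A sketch (not part of the diff) of how the two forward executions map onto the phase flags that pre_forward/post_forward check above. It mirrors what CheckpointFunction does in the distributed module and is illustrative only; the import path is an assumption based on the package layout shown earlier:

import torch
from transformer_engine.pytorch.distributed import activation_recompute_forward

def phase1_forward(module, *args, **kwargs):
    # First pass (inside CheckpointFunction.forward): activations are not kept.
    # pre_forward sees recompute enabled but not the recompute phase, so it
    # stashes a copy of the forward FP8 scaling tensors.
    with torch.no_grad():
        with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
            return module(*args, **kwargs)

def phase2_recompute(module, *detached_args, **kwargs):
    # Second pass (inside CheckpointFunction.backward): pre_forward swaps the
    # stashed scaling tensors back in, and post_forward restores the updated
    # ones, so both passes produce identical FP8 numerics.
    with torch.enable_grad():
        with activation_recompute_forward(activation_recompute=True, recompute_phase=True):
            return module(*detached_args, **kwargs)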