Unverified commit 82bc797f, authored by Alp Dener and committed by GitHub
Browse files

[PyTorch] Non-reentrant mode for activation recompute (#670)



* added non-reentrant mode support to TE checkpoint
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated get_cuda_rng_tracker kwarg to get_rng_state_tracker to remain consistent with other TE API
Signed-off-by: Alp Dener <adener@nvidia.com>

* docstring cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* added mechanism to disable bias_gelu_nvfusion in LayerNormMLP when checkpointing in non-reentrant mode
Signed-off-by: Alp Dener <adener@nvidia.com>

* refactored checkpoint and recompute hook names to match PyTorch implementation
Signed-off-by: Alp Dener <adener@nvidia.com>

* Fixed incorrect reference before assignment
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed argument error in calling native PyTorch checkpoint
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for missing docstrings
Signed-off-by: Alp Dener <adener@nvidia.com>

* Fix lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bias GELU fusion consistency between checkpoint test and reference comparison
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9b2fed51
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import math import math
import os import os
import sys
from typing import List, Optional from typing import List, Optional
import pytest import pytest
import copy import copy
...@@ -72,22 +73,27 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor: ...@@ -72,22 +73,27 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Optional[List[str]] = None,
) -> None:
    """Assert that two lists of tensors are pairwise bit-identical.

    All pairs are compared before failing, so the assertion message lists every
    mismatching tensor instead of stopping at the first one.

    Parameters
    ----------
    l1, l2 : list of torch.Tensor
        Tensor lists to compare position by position with `torch.equal`.
    names : list of str, optional
        Human-readable label for each position. When omitted, or when it has no
        entry for a mismatching position, the position index is reported instead.

    Raises
    ------
    AssertionError
        If the lists differ in length or any pair of tensors is not equal.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    mismatched = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.equal(t1, t2):
            continue
        # Fall back to the index when no label exists for this position, so a short
        # `names` list cannot turn a genuine mismatch into an unrelated IndexError.
        if names is not None and i < len(names):
            mismatched.append(f" {names[i]}\n")
        else:
            mismatched.append(f" tensor at idx={i}\n")
    assert not mismatched, "Output mismatches in:\n" + "".join(mismatched)
def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool: def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool:
"""Ensures two lists are equal.""" """Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs." assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2): for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=atol) result = torch.allclose(t1, t2, atol=atol)
if not result: if not result:
diff = torch.abs(t1 - t2).flatten() diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff) m = torch.argmax(diff)
msg = (f"Outputs not close enough." msg = (f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} " f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})." f"(diff {diff[m].item()})."
...@@ -457,7 +463,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par ...@@ -457,7 +463,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
assert_all_equal(outputs, outputs_recompute) assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): def _test_e2e_full_recompute(
bs, dtype, config, fp8,
fp8_model_params=False,
recompute=False,
use_reentrant=True
):
reset_rng_states() reset_rng_states()
FP8GlobalStateManager.reset() FP8GlobalStateManager.reset()
...@@ -494,8 +505,9 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec ...@@ -494,8 +505,9 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
) )
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant
).cuda() ).cuda()
if use_reentrant:
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
...@@ -503,12 +515,13 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec ...@@ -503,12 +515,13 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
if recompute: if recompute:
te_out = te_checkpoint( te_out = te_checkpoint(
block, block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states, te_inp_hidden_states,
attention_mask=te_inp_attn_mask, attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False, checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant,
) )
else: else:
te_out = block( te_out = block(
...@@ -520,11 +533,17 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec ...@@ -520,11 +533,17 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
loss.backward() loss.backward()
torch.cuda.synchronize() torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad] outputs = [te_out]
for p in block.parameters(): names = ["output"]
if use_reentrant:
outputs.append(te_inp_hidden_states.grad)
names.append("input")
for name, p in block.named_parameters():
if p.requires_grad: if p.requires_grad:
outputs.append(p.grad) outputs.append(p.grad)
return outputs names.append(name)
return outputs, names
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -532,15 +551,27 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec ...@@ -532,15 +551,27 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params): @pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant):
    """Check that full activation recompute yields bit-identical outputs and gradients.

    Runs the end-to-end helper once without and once with recompute, using the same
    `use_reentrant` setting for both, and compares every returned tensor exactly.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    fusion_env = "NVTE_BIAS_GELU_NVFUSION"
    # Remember any value the user exported so it can be put back afterwards.
    saved_fusion = os.environ.get(fusion_env)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_env] = "0"

    try:
        outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
                                                  recompute=False, use_reentrant=use_reentrant)
        outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
                                                        recompute=True, use_reentrant=use_reentrant)
    finally:
        if not use_reentrant:
            # Reset bias+GELU fusion flag to avoid contaminating other tests. This runs
            # even when one of the runs above raises, and restores a pre-existing value
            # instead of unconditionally deleting the variable.
            if saved_fusion is None:
                os.environ.pop(fusion_env, None)
            else:
                os.environ[fusion_env] = saved_fusion

    assert_all_equal(outputs, outputs_recompute, names=names)
def _test_e2e_checkpointing_get_model(config, dtype): def _test_e2e_checkpointing_get_model(config, dtype):
......
...@@ -2472,9 +2472,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -2472,9 +2472,9 @@ class DotProductAttention(torch.nn.Module):
hidden_states = checkpoint( hidden_states = checkpoint(
custom_forward, custom_forward,
False, distribute_saved_activations=False,
self.get_rng_state_tracker, get_rng_state_tracker=self.get_rng_state_tracker,
self.tp_group, tp_group=self.tp_group,
*forward_args, *forward_args,
**forward_kwargs, **forward_kwargs,
) )
......
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Methods needed for distributed training (DP/TP).""" """Methods needed for distributed training (DP/TP)."""
from contextlib import contextmanager import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, Union, Optional, Callable, Tuple from typing import Any, Dict, Union, Optional, Callable, Tuple
import torch import torch
from torch.cuda import _lazy_call from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable from torch.utils.checkpoint import detach_variable, noop_context_fn
from .utils import safely_set_viewless_tensor_data from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type from .constants import dist_group_type
...@@ -24,6 +25,8 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { ...@@ -24,6 +25,8 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"partition_stride": 1, "partition_stride": 1,
} }
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False _FP8_ACTIVATION_RECOMPUTE_PHASE = False
...@@ -137,11 +140,7 @@ def gather_split_1d_tensor( ...@@ -137,11 +140,7 @@ def gather_split_1d_tensor(
return gathered return gathered
@contextmanager class activation_recompute_forward(AbstractContextManager, ContextDecorator):
def activation_recompute_forward(
activation_recompute: bool = False,
recompute_phase: bool = False,
) -> None:
"""Context manager used to control the forward runtime behavior when executed """Context manager used to control the forward runtime behavior when executed
under the `CheckpointFunction` function. For running FP8, the forward pass will under the `CheckpointFunction` function. For running FP8, the forward pass will
run without storing intermediate activations. Instead, the forward pass saves run without storing intermediate activations. Instead, the forward pass saves
...@@ -149,13 +148,24 @@ def activation_recompute_forward( ...@@ -149,13 +148,24 @@ def activation_recompute_forward(
retrieved, and the forward pass is computed again while tracking the intermediate retrieved, and the forward pass is computed again while tracking the intermediate
activations, followed by calculation of gradients using these values. activations, followed by calculation of gradients using these values.
""" """
def __init__(
self,
activation_recompute: bool = False,
recompute_phase: bool = False
):
super().__init__()
self.activation_recompute = activation_recompute
self.recompute_phase = recompute_phase
def __enter__(self):
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
try:
_FP8_ACTIVATION_RECOMPUTE_ENABLED = ( _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
activation_recompute and FP8GlobalStateManager.is_fp8_enabled()) self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
_FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase )
yield _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
finally:
def __exit__(self, *exc_details):
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False _FP8_ACTIVATION_RECOMPUTE_PHASE = False
...@@ -170,7 +180,7 @@ def in_fp8_activation_recompute_phase() -> bool: ...@@ -170,7 +180,7 @@ def in_fp8_activation_recompute_phase() -> bool:
return _FP8_ACTIVATION_RECOMPUTE_PHASE return _FP8_ACTIVATION_RECOMPUTE_PHASE
class CheckpointFunction(torch.autograd.Function): class _CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with """This function is adapted from torch.utils.checkpoint with
two main changes: two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
...@@ -183,8 +193,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -183,8 +193,9 @@ class CheckpointFunction(torch.autograd.Function):
ctx, ctx,
run_function: Callable, run_function: Callable,
distribute_saved_activations: bool, distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable, get_rng_state_tracker: Union[Callable, None],
tp_group: dist_group_type, tp_group: Union[dist_group_type, None],
context_fn: Union[Callable, None],
kwargs: Dict[str, Any], kwargs: Dict[str, Any],
*args: Tuple[torch.Tensor, ...], *args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
...@@ -196,9 +207,14 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -196,9 +207,14 @@ class CheckpointFunction(torch.autograd.Function):
# Copy the rng states. # Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state() ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
with torch.no_grad(): if context_fn is not None:
forward_ctx, recompute_ctx = context_fn()
else:
forward_ctx, recompute_ctx = noop_context_fn()
with torch.no_grad(), forward_ctx:
with activation_recompute_forward( with activation_recompute_forward(
activation_recompute=True, recompute_phase=False activation_recompute=True, recompute_phase=False
): ):
...@@ -220,8 +236,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -220,8 +236,9 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args] tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
ctx.save_for_backward(*tensor_inputs) ctx.save_for_backward(*tensor_inputs)
ctx.get_cuda_rng_tracker = get_cuda_rng_tracker ctx.get_rng_state_tracker = get_rng_state_tracker
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.recompute_ctx = recompute_ctx
ctx.kwargs = kwargs ctx.kwargs = kwargs
return outputs return outputs
...@@ -242,7 +259,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -242,7 +259,7 @@ class CheckpointFunction(torch.autograd.Function):
for (t, arg) in zip(ctx.saved_tensors, ctx.inputs) for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
) )
get_cuda_rng_tracker = ctx.get_cuda_rng_tracker get_rng_state_tracker = ctx.get_rng_state_tracker
if ctx.distribute_saved_activations: if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data( safely_set_viewless_tensor_data(
...@@ -255,16 +272,18 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -255,16 +272,18 @@ class CheckpointFunction(torch.autograd.Function):
# Store the current states. # Store the current states.
bwd_cpu_rng_state = torch.get_rng_state() bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state() bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass. # Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state) torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state) _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass. # Compute the forward pass.
detached_inputs = detach_variable(inputs) detached_inputs = detach_variable(inputs)
with torch.enable_grad(): with torch.enable_grad(), ctx.recompute_ctx:
with activation_recompute_forward( with activation_recompute_forward(
activation_recompute=True, recompute_phase=True activation_recompute=True, recompute_phase=True
): ):
...@@ -273,7 +292,8 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -273,7 +292,8 @@ class CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function. # Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state) torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state) _set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
outputs = (outputs,) outputs = (outputs,)
...@@ -295,14 +315,168 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -295,14 +315,168 @@ class CheckpointFunction(torch.autograd.Function):
inp.grad if isinstance(inp, torch.Tensor) else None inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs for inp in detached_inputs
) )
return (None, None, None, None, None) + grads return (None, None, None, None, None, None) + grads
class _CheckpointFrame:
    """
    Storage frame for forward RNG states and detached activations from the forward recompute.

    One frame is created per checkpointed forward pass. It carries:
      * `recompute_fn`: callable that re-runs the checkpointed forward pass,
      * `recomputed`: list that receives the detached activations saved during the recompute,
      * `count`: running number of tensors packed during the original forward pass,
      * `fwd_rng_states` / `bwd_rng_states`: cached RNG snapshots (`None` until cached).
    """

    def __init__(
        self,
        recompute_fn: Callable,
        get_rng_state_tracker: Optional[Callable],
    ):
        """
        Parameters
        ----------
        recompute_fn : Callable
            Function that recomputes the forward pass.
        get_rng_state_tracker : Callable or None
            Callable returning an object with `get_states()` / `set_states()`. When `None`,
            only the CPU and CUDA generator states are cached and restored.
        """
        self.recompute_fn = recompute_fn
        self.recomputed = []
        self.count = 0
        self.get_rng_state_tracker = get_rng_state_tracker
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Cache fwd/bwd RNG states in the frame to restore later."""
        rng_states = (
            torch.get_rng_state(),
            torch.cuda.get_rng_state(),
        )
        if self.get_rng_state_tracker is not None:
            # Tracker states, when present, always occupy the third slot.
            rng_states += (self.get_rng_state_tracker().get_states(),)

        if forward:
            self.fwd_rng_states = rng_states
        else:
            self.bwd_rng_states = rng_states

    def restore_rng_states(self, forward=True):
        """Restore fwd/bwd RNG states that were previously cached into the frame.

        Raises
        ------
        RuntimeError
            If the requested states were never cached with `cache_rng_states()`.
        """
        rng_states = self.fwd_rng_states if forward else self.bwd_rng_states
        if rng_states is None:
            # Fail with a clear message instead of "'NoneType' object is not subscriptable".
            phase = "forward" if forward else "backward"
            raise RuntimeError(
                f"No {phase} RNG states have been cached in this checkpoint frame; "
                "call cache_rng_states() before restore_rng_states()."
            )

        torch.set_rng_state(rng_states[0])
        _set_cuda_rng_state(rng_states[1])
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(rng_states[2])
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):  # pylint: disable=too-few-public-methods
    """Saved-tensor hooks that are active while the checkpointed forward pass is recomputed.

    Every tensor that autograd saves during the recompute is detached and appended to
    `frame.recomputed`, in the order in which it is saved.
    """

    def __init__(self, frame):

        def _record_activation(tensor):
            """
            Pack hook: store a detached copy of the saved activation in the frame and hand
            another detached view back to autograd.
            """
            frame.recomputed.append(tensor.detach())
            return tensor.detach()

        def _passthrough(packed):
            """
            Unpack hook: identity. It is not expected to run, because no backward pass is
            triggered on the graph built by the forward recomputation.
            """
            return packed

        super().__init__(_record_activation, _passthrough)
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):  # pylint: disable=too-few-public-methods
    """Saved-tensor hooks that are active during the original (checkpointed) forward pass.

    Saved tensors are replaced by integer slots on the way in. On the way out, the forward
    pass is recomputed once and the slots are resolved against `frame.recomputed`.
    """

    def __init__(self, frame, args, kwargs):

        def _pack_as_slot(tensor):
            """
            Pack hook: drop the tensor autograd wants to save and return its running
            position (`frame.count`) as a placeholder instead.
            """
            del tensor
            slot = frame.count
            frame.count = slot + 1
            return slot

        def _recompute_activations():
            """Re-run the forward pass under the forward RNG states, filling `frame.recomputed`."""
            # Store current RNG states in the backward pass
            frame.cache_rng_states(forward=False)
            # Set RNG states to what we saved before the forward pass
            frame.restore_rng_states(forward=True)
            # Recompute the forward pass; the hook collects every saved activation
            with _recomputation_hook(frame):
                frame.recompute_fn(*args, **kwargs)
            # Restore RNG states back to the backward pass
            frame.restore_rng_states(forward=False)

        def _unpack_from_frame(slot):
            """
            Unpack hook: the first call (while `frame.recomputed` is still empty) triggers
            the recomputation. Every call then returns the recomputed activation stored at
            `slot` and clears that slot so the frame no longer references the tensor.
            """
            if not frame.recomputed:
                _recompute_activations()

            activation, frame.recomputed[slot] = frame.recomputed[slot], None
            return activation

        super().__init__(_pack_as_slot, _unpack_from_frame)
def use_reentrant_activation_recompute() -> bool:
    """Returns `True` if activation recompute is using the 'reentrant' method.

    Reads the module-level `_USE_REENTRANT_ACTIVATION_RECOMPUTE` flag, which `checkpoint()`
    overwrites on every call from its `use_reentrant` keyword argument (default `True`).
    """
    # NOTE(review): nothing in the visible code resets the flag after `checkpoint()` returns,
    # so this appears to report the setting of the most recent call — confirm.
    return _USE_REENTRANT_ACTIVATION_RECOMPUTE
def get_activation_recompute_contexts():
    """Build the pair of `activation_recompute_forward` contexts used by the checkpoint.

    Returns
    -------
    tuple
        `(forward_ctx, recompute_ctx)`. Both are created with `activation_recompute=True`;
        the first uses `recompute_phase=False` (original forward pass) and the second
        `recompute_phase=True` (forward recomputation).
    """
    return (
        activation_recompute_forward(activation_recompute=True, recompute_phase=False),
        activation_recompute_forward(activation_recompute=True, recompute_phase=True),
    )
def _is_te_module(module):
    """
    Check if given module is a Transformer Engine module that requires the TE checkpoint
    implementation for activation recompute.

    Returns `True` when `module` is an instance of any of the TE classes listed below,
    `False` otherwise.
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably this avoids a circular import with these modules — confirm.
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    # `isinstance` accepts a tuple of classes and is true if any of them matches.
    return isinstance(module, te_classes)
def checkpoint( def checkpoint(
function: Callable, function: Callable,
distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable,
tp_group: dist_group_type,
*args: Tuple[torch.Tensor, ...], *args: Tuple[torch.Tensor, ...],
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
...@@ -323,34 +497,116 @@ def checkpoint( ...@@ -323,34 +497,116 @@ def checkpoint(
PyTorch's :attr:`save_for_backward` method. :attr:`function` must be callable to produce PyTorch's :attr:`save_for_backward` method. :attr:`function` must be callable to produce
valid outputs with the inputs :attr:`args` and :attr:`kwargs`. valid outputs with the inputs :attr:`args` and :attr:`kwargs`.
.. warning::
`use_reentrant=False` does not support early stopping, and will execute the entire forward
pass for the checkpointed module when recomputing activations in the backward pass.
Parameters Parameters
---------- ----------
function: Callable function: Callable
pytorch module used to run the forward and backward passes using pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`. the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool distribute_saved_activations: bool, default = False
if set to `True`, the first tensor argument is distributed across the if set to `True` and `use_reentrant=True`, first tensor argument is distributed
specified tensor parallel group (`tp_group`) before saving it for the across the specified tensor parallel group (`tp_group`) before saving it for the
backward pass. backward pass. This has no effect when `use_reentrant=False`.
get_cuda_rng_tracker: `Callable` get_rng_state_tracker: `Callable`, default = None
python callable which returns an instance of :func:`CudaRNGStatesTracker`. python callable which returns an instance of :func:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = `None` tp_group : ProcessGroup, default = None
tensor parallel process group. tensor parallel process group. Used only when `distribute_saved_activations=True`
and `use_reentrant=True`. If `None`, it falls back to the default group.
use_reentrant : bool, default = True
perform checkpointing in reentrant mode.
args : tuple args : tuple
tuple of torch tensors for inputs to :attr:`function`. tuple of torch tensors for inputs to :attr:`function`.
kwargs : dict kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`. dictionary of string keys for keyword arguments to :attr:`function`.
""" """
# Pop out te.distributed.checkpoint() arguments
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
tp_group = kwargs.pop("tp_group", None)
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
# Trigger the native PyTorch checkpoint if:
# 1. `function` is a `torch.nn.Module`
# AND
# 2. `function` is NOT a TE module
context_fn = kwargs.pop("context_fn", noop_context_fn)
determinism_check = kwargs.pop("determinism_check", "default")
debug = kwargs.pop("debug", False)
if isinstance(function, torch.nn.Module) and not _is_te_module(function):
return torch.utils.checkpoint.checkpoint(
function,
*args,
use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
context_fn=context_fn,
determinism_check=determinism_check,
debug=debug,
**kwargs
)
return CheckpointFunction.apply( # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
# NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
# cannot be sure there are no TE modules inside the function. It also means we might run
# the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
# user context function.
del determinism_check, debug
if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
# If saved activations need to be distributed but there is no process group,
# default to the world group.
if distribute_saved_activations:
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group
# Make sure at least one tensor input has `requires_grad=True`
input_requires_grad = False
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
input_requires_grad = True
break
assert input_requires_grad, (
"`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
)
return _CheckpointFunction.apply(
function, function,
distribute_saved_activations, distribute_saved_activations,
get_cuda_rng_tracker, get_rng_state_tracker,
tp_group, tp_group,
context_fn,
kwargs, kwargs,
*args, *args,
) )
if distribute_saved_activations:
warnings.warn(
"`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
"The non-reentrant checkpoint implementation does not manually store forward "
"inputs for the activation recompute in the backward pass, and instead leverages "
"the autograd engine's pack/unpack hooks."
)
user_forward_ctx, user_recompute_ctx = context_fn()
te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()
def recompute_fn(*args, **kwargs):
with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx:
function(*args, **kwargs)
# Initialize a new checkpoint frame for each new forward pass.
new_frame = _CheckpointFrame(
recompute_fn,
get_rng_state_tracker,
)
new_frame.cache_rng_states(forward=True)
with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
out = function(*args, **kwargs)
return out
class CudaRNGStatesTracker: class CudaRNGStatesTracker:
""" """
...@@ -421,7 +677,7 @@ class CudaRNGStatesTracker: ...@@ -421,7 +677,7 @@ class CudaRNGStatesTracker:
_set_cuda_rng_state(orig_rng_state) _set_cuda_rng_state(orig_rng_state)
@contextmanager @contextmanager
def fork(self, name: str = "model-parallel-rng") -> None: def fork(self, name: str = "model-parallel-rng"):
""" """
Fork the cuda rng state, perform operations, and exit with Fork the cuda rng state, perform operations, and exit with
the original state. the original state.
......
...@@ -43,6 +43,7 @@ from ..distributed import ( ...@@ -43,6 +43,7 @@ from ..distributed import (
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled, is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute,
) )
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
...@@ -1415,6 +1416,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1415,6 +1416,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch is_first_microbatch
) )
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if (self.bias_gelu_nvfusion
and not use_reentrant_activation_recompute()):
self.bias_gelu_nvfusion = False
from ..cpu_offload import CPUOffloadEnabled from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment