Unverified commit 82bc797f authored by Alp Dener, committed by GitHub

[PyTorch] Non-reentrant mode for activation recompute (#670)



* added non-reentrant mode support to TE checkpoint
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated get_cuda_rng_tracker kwarg to get_rng_state_tracker to remain consistent with other TE API
Signed-off-by: Alp Dener <adener@nvidia.com>

* docstring cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* added mechanism to disable bias_gelu_nvfusion in LayerNormMLP when checkpointing in non-reentrant mode
Signed-off-by: Alp Dener <adener@nvidia.com>

* refactored checkpoint and recompute hook names to match PyTorch implementation
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect reference before assignment
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed argument error in calling native PyTorch checkpoint
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for missing docstrings
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bias GELU fusion consistency between checkpoint test and reference comparison
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9b2fed51
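
For context, a minimal usage sketch of the reworked checkpoint API introduced by this commit. Names such as layer, hidden_states, and attn_mask are illustrative, and the import path assumes the transformer_engine/pytorch/distributed.py module shown in the diff below:

from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint

# Checkpoint-specific options are now keyword arguments; remaining args/kwargs are
# forwarded to the wrapped module. use_reentrant=False selects the new hook-based
# (non-reentrant) recompute path.
out = te_checkpoint(
    layer,                               # e.g. a TransformerLayer instance
    hidden_states,                       # positional tensor inputs to the module
    attention_mask=attn_mask,            # module kwargs are forwarded unchanged
    distribute_saved_activations=False,  # only honored when use_reentrant=True
    tp_group=None,
    get_rng_state_tracker=None,
    use_reentrant=False,
)
out.sum().backward()                     # activations are recomputed during backward
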
@@ -4,6 +4,7 @@
import math
import os
import sys
from typing import List, Optional
import pytest
import copy
@@ -72,22 +73,27 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
assert torch.equal(t1, t2), "Output mismatch."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors
def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=atol)
if not result:
diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff)
msg = (f"Outputs not close enough."
msg = (f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})."
@@ -457,7 +463,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
def _test_e2e_full_recompute(
bs, dtype, config, fp8,
fp8_model_params=False,
recompute=False,
use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -494,21 +505,23 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant
).cuda()
te_inp_hidden_states.retain_grad()
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8):
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant,
)
else:
te_out = block(
@@ -520,11 +533,17 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
outputs = [te_out]
names = ["output"]
if use_reentrant:
outputs.append(te_inp_hidden_states.grad)
names.append("input")
for name, p in block.named_parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
names.append(name)
return outputs, names
@pytest.mark.parametrize("dtype", param_types)
@@ -532,15 +551,27 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, rec
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
if not use_reentrant:
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"
outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=False, use_reentrant=use_reentrant)
outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=True, use_reentrant=use_reentrant)
if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"]
assert_all_equal(outputs, outputs_recompute, names=names)
def _test_e2e_checkpointing_get_model(config, dtype):
......
@@ -2472,9 +2472,9 @@ class DotProductAttention(torch.nn.Module):
hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
distribute_saved_activations=False,
get_rng_state_tracker=self.get_rng_state_tracker,
tp_group=self.tp_group,
*forward_args,
**forward_kwargs,
)
......
@@ -3,12 +3,13 @@
# See LICENSE for license information.
"""Methods needed for distributed training (DP/TP)."""
from contextlib import contextmanager
import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, Union, Optional, Callable, Tuple
import torch
from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable
from torch.utils.checkpoint import detach_variable, noop_context_fn
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
@@ -24,6 +25,8 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"partition_stride": 1,
}
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False
@@ -137,11 +140,7 @@ def gather_split_1d_tensor(
return gathered
@contextmanager
def activation_recompute_forward(
activation_recompute: bool = False,
recompute_phase: bool = False,
) -> None:
class activation_recompute_forward(AbstractContextManager, ContextDecorator):
"""Context manager used to control the forward runtime behavior when executed
under the `CheckpointFunction` function. For running FP8, the forward pass will
run without storing intermediate activations. Instead, the forward pass saves
@@ -149,13 +148,24 @@
retrieved, and the forward pass is computed again while tracking the intermediate
activations, followed by calculation of gradients using these values.
"""
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
try:
def __init__(
self,
activation_recompute: bool = False,
recompute_phase: bool = False
):
super().__init__()
self.activation_recompute = activation_recompute
self.recompute_phase = recompute_phase
def __enter__(self):
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
_FP8_ACTIVATION_RECOMPUTE_ENABLED = (
activation_recompute and FP8GlobalStateManager.is_fp8_enabled())
_FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
yield
finally:
self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
)
_FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
def __exit__(self, *exc_details):
global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False
@@ -170,7 +180,7 @@ def in_fp8_activation_recompute_phase() -> bool:
return _FP8_ACTIVATION_RECOMPUTE_PHASE
class CheckpointFunction(torch.autograd.Function):
class _CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
@@ -183,8 +193,9 @@ class CheckpointFunction(torch.autograd.Function):
ctx,
run_function: Callable,
distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable,
tp_group: dist_group_type,
get_rng_state_tracker: Union[Callable, None],
tp_group: Union[dist_group_type, None],
context_fn: Union[Callable, None],
kwargs: Dict[str, Any],
*args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
@@ -196,9 +207,14 @@ class CheckpointFunction(torch.autograd.Function):
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
if context_fn is not None:
forward_ctx, recompute_ctx = context_fn()
else:
forward_ctx, recompute_ctx = noop_context_fn()
with torch.no_grad(), forward_ctx:
with activation_recompute_forward(
activation_recompute=True, recompute_phase=False
):
@@ -220,8 +236,9 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
ctx.save_for_backward(*tensor_inputs)
ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
ctx.get_rng_state_tracker = get_rng_state_tracker
ctx.tp_group = tp_group
ctx.recompute_ctx = recompute_ctx
ctx.kwargs = kwargs
return outputs
@@ -242,7 +259,7 @@ class CheckpointFunction(torch.autograd.Function):
for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
)
get_cuda_rng_tracker = ctx.get_cuda_rng_tracker
get_rng_state_tracker = ctx.get_rng_state_tracker
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
@@ -255,16 +272,18 @@ class CheckpointFunction(torch.autograd.Function):
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
with torch.enable_grad(), ctx.recompute_ctx:
with activation_recompute_forward(
activation_recompute=True, recompute_phase=True
):
@@ -273,7 +292,8 @@ class CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
@@ -295,14 +315,168 @@ class CheckpointFunction(torch.autograd.Function):
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None, None, None, None, None) + grads
return (None, None, None, None, None, None) + grads
class _CheckpointFrame:
"""
Storage frame for forward RNG states and detached activations from the forward recompute.
"""
def __init__(
self,
recompute_fn: Callable,
get_rng_state_tracker: Callable
):
self.recompute_fn = recompute_fn
self.recomputed = []
self.count = 0
self.get_rng_state_tracker = get_rng_state_tracker
self.fwd_rng_states = None
self.bwd_rng_states = None
def cache_rng_states(self, forward=True):
"""Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = (
torch.get_rng_state(),
torch.cuda.get_rng_state(),
)
if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(), )
if forward:
self.fwd_rng_states = rng_states
else:
self.bwd_rng_states = rng_states
def restore_rng_states(self, forward=True):
"""Restore fwd/bwd RNG states that were previously cached into the frame."""
if forward:
rng_states = self.fwd_rng_states
else:
rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1])
if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2])
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): # pylint: disable=too-few-public-methods
"""torch.autograd hook for packing/unpacking tensors during the activation recompute phase."""
def __init__(self, frame):
def pack_hook(x):
"""
Packing hook for each recomputed activation passed into the `ctx.save_for_backward()`
call in the forward recomputation.
"""
frame.recomputed.append(x.detach())
return x.detach()
def unpack_hook(x):
"""
No-op unpack hook that will never be called because the backward pass for the
forward recomputation is never triggered.
"""
return x
super().__init__(pack_hook, unpack_hook)
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): # pylint: disable=too-few-public-methods
"""torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass."""
def __init__(self, frame, args, kwargs):
def pack_hook(x):
"""
Packing hook for each tensor passed into `ctx.save_for_backward()` call in the
forward pass. Since this is the first forward pass, we discard the tensor and instead
pack a placeholder tensor index into the autograd engine context.
"""
del x
idx = frame.count
frame.count += 1
return idx
def unpack_hook(idx):
"""
Unpacking hook for each tensor that comes out of the `ctx.saved_tensors` call in the
backward pass. The first time this is called, the _recomputation_hook will save all the
activation tensors from `ctx.save_for_backward()` in the forward recomputation into the
_CheckpointFrame. Subsequent calls will simply return the already recomputed activation
tensor at the given index of the _CheckpointFrame storage.
"""
if not frame.recomputed:
# Store current RNG states in the backward pass
frame.cache_rng_states(forward=False)
# Set RNG states to what we saved before the forward pass
frame.restore_rng_states(forward=True)
# Recompute the forward pass
with _recomputation_hook(frame):
frame.recompute_fn(*args, **kwargs)
# Restore RNG states back to the backward pass
frame.restore_rng_states(forward=False)
# Return the already recomputed activation tensor at the given index
activation = frame.recomputed[idx]
frame.recomputed[idx] = None
return activation
super().__init__(pack_hook, unpack_hook)
def use_reentrant_activation_recompute():
"""Returns `True` if activation recompute is using the 'reentrant' method."""
return _USE_REENTRANT_ACTIVATION_RECOMPUTE
def get_activation_recompute_contexts():
"""Returns context objects for the checkpointed forward pass and the forward recompute phase."""
forward_ctx = activation_recompute_forward(
activation_recompute=True,
recompute_phase=False,
)
recompute_ctx = activation_recompute_forward(
activation_recompute=True,
recompute_phase=True,
)
return forward_ctx, recompute_ctx
def _is_te_module(module):
"""
Check if given module is a Transformer Engine module that requires the TE checkpoint
implementation for activation recompute.
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
LayerNorm,
RMSNorm,
TransformerEngineBaseModule,
UnfusedDotProductAttention,
DotProductAttention,
MultiheadAttention,
TransformerLayer,
]
is_te_module = False
for te_class in te_classes_list:
if isinstance(module, te_class):
is_te_module = True
break
return is_te_module
def checkpoint(
function: Callable,
distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable,
tp_group: dist_group_type,
*args: Tuple[torch.Tensor, ...],
**kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
@@ -323,34 +497,116 @@ def checkpoint(
PyTorch's :attr:`save_for_backward` method. :attr:`function` must be callable to produce
valid outputs with the inputs :attr:`args` and :attr:`kwargs`.
.. warning::
`use_reentrant=False` does not support early stopping, and will execute the entire forward
pass for the checkpointed module when recomputing activations in the backward pass.
Parameters
----------
function: Callable
pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool
if set to `True`, the first tensor argument is distributed across the
specified tensor parallel group (`tp_group`) before saving it for the
backward pass.
get_cuda_rng_tracker: `Callable`
distribute_saved_activations: bool, default = False
if set to `True` and `use_reentrant=True`, first tensor argument is distributed
across the specified tensor parallel group (`tp_group`) before saving it for the
backward pass. This has no effect when `use_reentrant=False`.
get_rng_state_tracker: `Callable`, default = None
python callable which returns an instance of :func:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when `distribute_saved_activations=True`
and `use_reentrant=True`. If `None`, it falls back to the default group.
use_reentrant : bool, default = True
perform checkpointing in reentrant mode.
args : tuple
tuple of torch tensors for inputs to :attr:`function`.
kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`.
"""
# Pop out te.distributed.checkpoint() arguments
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
tp_group = kwargs.pop("tp_group", None)
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
# Trigger the native PyTorch checkpoint if:
# 1. `function` is a `torch.nn.Module`
# AND
# 2. `function` is NOT a TE module
context_fn = kwargs.pop("context_fn", noop_context_fn)
determinism_check = kwargs.pop("determinism_check", "default")
debug = kwargs.pop("debug", False)
if isinstance(function, torch.nn.Module) and not _is_te_module(function):
return torch.utils.checkpoint.checkpoint(
function,
*args,
use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
context_fn=context_fn,
determinism_check=determinism_check,
debug=debug,
**kwargs
)
# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
# NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
# cannot be sure there are no TE modules inside the function. It also means we might run
# the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
# user context function.
del determinism_check, debug
if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
# If saved activations need to be distributed but there is no process group,
# default to the world group.
if distribute_saved_activations:
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group
# Make sure at least one tensor input has `requires_grad=True`
input_requires_grad = False
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
input_requires_grad = True
break
assert input_requires_grad, (
"`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
)
return CheckpointFunction.apply(
function,
distribute_saved_activations,
get_cuda_rng_tracker,
tp_group,
kwargs,
*args,
return _CheckpointFunction.apply(
function,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
context_fn,
kwargs,
*args,
)
if distribute_saved_activations:
warnings.warn(
"`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
"The non-reentrant checkpoint implementation does not manually store forward "
"inputs for the activation recompute in the backward pass, and instead leverages "
"the autograd engine's pack/unpack hooks."
)
user_forward_ctx, user_recompute_ctx = context_fn()
te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()
def recompute_fn(*args, **kwargs):
with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx:
function(*args, **kwargs)
# Initialize a new checkpoint frame for each new forward pass.
new_frame = _CheckpointFrame(
recompute_fn,
get_rng_state_tracker,
)
new_frame.cache_rng_states(forward=True)
with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
out = function(*args, **kwargs)
return out
class CudaRNGStatesTracker:
"""
@@ -421,7 +677,7 @@ class CudaRNGStatesTracker:
_set_cuda_rng_state(orig_rng_state)
@contextmanager
def fork(self, name: str = "model-parallel-rng") -> None:
def fork(self, name: str = "model-parallel-rng"):
"""
Fork the cuda rng state, perform operations, and exit with
the original state.
......
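
The non-reentrant path added above is built on PyTorch's saved-tensors hooks: the checkpointed forward packs a placeholder index in place of each tensor autograd would save, and the first unpack during backward restores the forward RNG states and re-runs the forward under _recomputation_hook to repopulate the frame. A standalone sketch of that pack/unpack mechanism, for illustration only (this is not TE code; the stash list stands in for the recompute that TE performs lazily):

import torch

stash = []

def pack_hook(t):
    # Record a placeholder instead of letting autograd hold the activation.
    # The TE implementation discards the tensor here and recomputes it at unpack
    # time; this sketch simply stashes it to keep the example self-contained.
    stash.append(t.detach())
    return len(stash) - 1

def unpack_hook(idx):
    # Called lazily during backward; TE triggers the forward recompute on the first call.
    return stash[idx]

x = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.nn.functional.gelu(x @ x).sum()
y.backward()  # unpack_hook supplies the saved activations here
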
@@ -43,6 +43,7 @@ from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute,
)
from .. import cpp_extensions as tex
@@ -1415,6 +1416,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch
)
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if (self.bias_gelu_nvfusion
and not use_reentrant_activation_recompute()):
self.bias_gelu_nvfusion = False
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
......
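
As the updated test above does, bit-wise comparisons with use_reentrant=False can be made by disabling the bias+GELU fusion through the environment before the model is constructed. A sketch; run_comparison is a hypothetical stand-in for building the model and running the checkpointed and reference passes:

import os

os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"  # set before model construction, as the test does
try:
    run_comparison()  # hypothetical helper: build model, run checkpointed and reference fwd/bwd, compare
finally:
    del os.environ["NVTE_BIAS_GELU_NVFUSION"]  # avoid contaminating other runs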