Unverified Commit 05bfa3f8 authored by Jaime, committed by GitHub


[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag (#2311)

* custom tests for selective activation checkpointing for layernorm mlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add selective layernorm mlp to te.pytorch
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* update test and fix SLNMLP bug
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* implement slnmlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix tests pointed out by greptile app bot; still pass
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* minor formatting change in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove duplicate import in test/pytorch/selective_layernorm_mlp/test_recipe.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* clean up tests, remove unused imports
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove unused paths in test_deferred_init
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix issue with zero_centered_gamma in test_numerics reference implementation
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* clean up tests
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make comparison.py more extensive, cleaner output
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix typo flagged by greptile bot in compare.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make selective activation checkpointing optional in slnmlp via checkpoint flag
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add comments to clarify logic
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add checkpoint param to pytests, change compare.py to compare checkpoint=False vs checkpoint=True, skip cuda graph tests for checkpoint=True
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* refactor tests to call modified LayerNormMLP
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* refactor to implement selective activation checkpointing directly in LayerNormMLP, also fix bug to reach cleanup logic in fwd
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix skip explanation for cuda_graphs.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make _recompute deal with lists instead of tuples
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix MOST cuda graph failures by initializing identical quantizers during fwd. Float8CurrentScaling with bf16 and fp16 still fails with checkpointing
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix cuda graphs issue, all tests pass now
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix small logic bugs, clean up
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* integrate tests into main testing scripts
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* incorporate rng state tracking in checkpointing
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up tests
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix return type mismatches
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove checkpoint test from test_recipe, add separate test in test_numerics
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor typo fix
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clear up assertions in tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add license and copyright info
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix lint issues in layernorm_mlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix cpu_offload_v1 error
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* possibly fix recomputation in cuda graph bug
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip cuda graphs test for SLNMLP with SM>=10.0 and delayed scaling
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo when setting IS_FIRST_FP8_MODULE
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

---------

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 30c0120b
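For context, the new flag is exercised like any other LayerNormMLP argument. A minimal usage sketch (sizes are illustrative; a CUDA device and a Transformer Engine install are assumed):

import torch
import transformer_engine.pytorch as te

# checkpoint=True drops the FC1/activation outputs after the forward pass and
# recomputes them during backward, trading compute for activation memory.
mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    checkpoint=True,
).cuda()

x = torch.randn(128, 1024, device="cuda", requires_grad=True)
out = mlp(x)
out.sum().backward()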
@@ -1030,6 +1030,7 @@ def test_layernorm_mlp():
         {"return_bias": True},
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
+        {"checkpoint": True},
     ]
     for kwargs in kwargs_list:
...
@@ -13,7 +13,7 @@ import transformer_engine.pytorch as te
 """
 Distributed numerics tests
-These tests test the numerical corectness of the TransformerEngine layers.
+These tests test the numerical correctness of the TransformerEngine layers.
 Tests are parametrized by the layer and fp8 precision.
 One test consists of running multiple configurations from file run_numerics.py
 Such design is due to the fact the initialization of one test is long
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import pytest

from transformer_engine.pytorch import LayerNormMLP

torch.manual_seed(1234)
device = torch.device("cuda")


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


class ModelConfig:
    def __init__(
        self,
        hidden_size: int = 128,
        ffn_hidden_size: int = 512,
        layers: int = 1,
    ):
        self._hidden_size = hidden_size
        self._ffn_hidden_size = ffn_hidden_size
        self._layers = layers

    def build(self):
        ln_list, sln_list = [], []
        for _ in range(self._layers):
            ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(
                device
            )
            sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(
                device
            )
            with torch.no_grad():
                sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone())
                sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone())
                sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone())
                sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone())
                sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone())
                sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone())
            ln_list.append(ln)
            sln_list.append(sln)
        ln_model = _Sequential(*ln_list)
        sln_model = _Sequential(*sln_list)
        return ln_model, sln_model


config = {
    "small": ModelConfig(128, 512, 12),
    "medium": ModelConfig(512, 2048, 12),
    "large": ModelConfig(1024, 4096, 12),
    "huge": ModelConfig(2048, 8192, 12),
}

seq_sizes = [2**7, 2**10, 2**14, 2**16]


def _warmup(model, tensor):
    for _ in range(3):
        model(tensor).sum().backward()


def _run_fwd(model, tensor):
    torch.cuda.reset_peak_memory_stats(device)
    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True
    )
    torch.cuda.synchronize()
    start_mem = torch.cuda.memory_allocated(device)
    start_time.record()
    out = model(tensor)
    end_time.record()
    end_time.synchronize()
    elapsed = start_time.elapsed_time(end_time)
    peak_mem = torch.cuda.max_memory_allocated(device)
    mem = float(peak_mem - start_mem)
    return out, elapsed, mem


def _run_bwd(model, out):
    model.zero_grad(set_to_none=False)
    loss = out.sum()
    torch.cuda.reset_peak_memory_stats(device)
    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True
    )
    torch.cuda.synchronize()
    start_mem = torch.cuda.memory_allocated(device)
    start_time.record()
    loss.backward()
    end_time.record()
    end_time.synchronize()
    elapsed = start_time.elapsed_time(end_time)
    peak_mem = torch.cuda.max_memory_allocated(device)
    mem = float(peak_mem - start_mem)
    param_grads = _collect_param_grads(model)
    return param_grads, elapsed, mem


def _max_diff(ref, other):
    """Return max absolute difference between two tensors or collections."""
    if ref is None or other is None:
        return 0.0
    if isinstance(ref, (list, tuple)):
        diffs = [_max_diff(r, o) for r, o in zip(ref, other)]
        return max(diffs) if diffs else 0.0
    return torch.max(torch.abs(ref.detach() - other.detach())).item()


def _collect_param_grads(model):
    grads = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        key = _param_key(name)
        if key is not None:
            grads[key] = param.grad.detach().clone()
    return grads


def _param_key(name):
    return name.split(".")[-1]


@pytest.mark.parametrize("size", config.keys())
@pytest.mark.parametrize("seq_size", seq_sizes)
def test_selective_activation_checkpoint(size, seq_size):
    ln_model, sln_model = config[size].build()
    data = torch.randn((seq_size, config[size]._hidden_size), device=device)

    _warmup(ln_model, data)
    ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data)
    ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out)

    _warmup(sln_model, data)
    sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data)
    sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out)

    assert ln_fwd_mem > 6 * sln_fwd_mem, (
        "selective activation checkpointing does not reduce forward memory by 6X, only by"
        f" {ln_fwd_mem/sln_fwd_mem}!"
    )
    assert ln_bwd_time < sln_bwd_time, (
        "selective activation checkpointing backward pass is NOT slower than native!"
        f" got Native LayerNormMLP Backward Time: {ln_bwd_time} ms and Selective Activation"
        f" Checkpointed LayerNormMLP Backward Time: {sln_bwd_time} ms"
    )

    diff = _max_diff(ln_fwd_out, sln_fwd_out)
    assert diff == 0.0, f"outputs are not equal! maximum difference {diff}"
    for key in [
        "layer_norm_weight",
        "layer_norm_bias",
        "fc1_weight",
        "fc1_bias",
        "fc2_weight",
        "fc2_bias",
    ]:
        diff = _max_diff(ln_grads[key], sln_grads[key])
        assert diff == 0.0, f"gradients for {key} are not equal! maximum difference: {diff}"
@@ -190,7 +190,8 @@ _test_cuda_graphs_modules: List[str] = [
     # creating TMA descriptor for MXFP8 quantization.
     "linear",
     "transformer",
-    "layernorm_mlp",
+    "layernorm_mlp_nocheckpoint",
+    "layernorm_mlp_checkpoint",
     "layernorm_linear",
     "mha",
     "linear_op",
...
@@ -232,12 +233,23 @@ def _test_cuda_graphs(
             )
             for _ in range(num_layers)
         ]
-    elif module == "layernorm_mlp":
+    elif module == "layernorm_mlp_nocheckpoint":
         modules = [
             LayerNormMLP(
                 model_config.hidden_size,
                 model_config.hidden_size,
                 params_dtype=dtype,
+                checkpoint=False,
+            )
+            for _ in range(num_layers)
+        ]
+    elif module == "layernorm_mlp_checkpoint":
+        modules = [
+            LayerNormMLP(
+                model_config.hidden_size,
+                model_config.hidden_size,
+                params_dtype=dtype,
+                checkpoint=True,
             )
             for _ in range(num_layers)
         ]
...
@@ -376,6 +388,17 @@ def test_make_graphed_callables(
         )
     if fp8_params:
         pytest.skip("NVFP4 params not supported")
+    if (
+        fp8
+        and fp8_recipe.delayed()
+        and torch.cuda.get_device_capability() >= (10, 0)
+        and module == "layernorm_mlp_checkpoint"
+    ):
+        pytest.skip(
+            "CUDA graphs not supported for LayerNormMLP "
+            "with checkpoint=True, SM>=10, "
+            "and DelayedScaling recipe"
+        )

     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
...
@@ -402,7 +425,8 @@ def test_make_graphed_callables(
 _test_make_graphed_callables_with_fp8_weight_caching_modules = [
     "transformer",
-    "layernorm_mlp",
+    "layernorm_mlp_nocheckpoint",
+    "layernorm_mlp_checkpoint",
     "layernorm_linear",
     "linear",
     "mha",
...
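For reference, a minimal sketch of the graphed path these module names feed into. Sizes are illustrative, and this assumes te.make_graphed_callables with its basic positional signature (mirroring torch.cuda.make_graphed_callables), not the full parametrization the test exercises:

import torch
import transformer_engine.pytorch as te

module = te.LayerNormMLP(1024, 4096, checkpoint=True).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Warm up, capture forward/backward into CUDA graphs, and return a callable
# with the same signature as the module.
graphed_module = te.make_graphed_callables(module, (sample_input,))

out = graphed_module(sample_input)
out.sum().backward()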
@@ -185,7 +185,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
         return dict(rtol=1e-3, atol=1e-5)
     if dtype == torch.bfloat16:
         return dict(rtol=1.6e-2, atol=1e-5)
-    raise ValueError(f"Unsuppored dtype ({dtype})")
+    raise ValueError(f"Unsupported dtype ({dtype})")


 def assert_allclose(
...
@@ -1363,7 +1363,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
     te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

-    # Shoule be bit-wise match
+    # Should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...
@@ -1696,7 +1696,11 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, bias, fuse_wgrad_accumulation
+    dtype,
+    bs,
+    model,
+    bias,
+    fuse_wgrad_accumulation,
 ):
     config = model_configs[model]
...
@@ -1747,6 +1751,58 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", [2])
+@pytest.mark.parametrize("model", ["small"])
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy_checkpoint(
+    dtype,
+    bs,
+    model,
+    bias,
+):
+    config = model_configs[model]
+
+    ln_mlp = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=True,
+    ).eval()
+
+    ln_mlp_ref = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=False,
+    ).eval()
+
+    # Share params
+    with torch.no_grad():
+        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+        if bias:
+            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+
+    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
+    te_outputs_ref = _test_granular_accuracy(
+        ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
+    )
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
 def _test_grouped_linear_accuracy(
     block,
     num_gemms,
...
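The bit-wise comparison (rtol=0, atol=0) is intentional: the checkpointed path replays the forward under the saved RNG and quantizer state, so the recomputed activations, and hence the outputs and gradients, should match the non-checkpointed path exactly rather than merely within floating-point tolerance.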
...@@ -29,7 +29,6 @@ from transformer_engine.pytorch.quantization import ( ...@@ -29,7 +29,6 @@ from transformer_engine.pytorch.quantization import (
) )
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
......
@@ -525,6 +525,7 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("microbatching", all_boolean)
+@pytest.mark.parametrize("checkpoint", all_boolean)
 def test_sanity_layernorm_mlp(
     dtype,
     fp8_recipe,
@@ -535,6 +536,7 @@ def test_sanity_layernorm_mlp(
     activation,
     normalization,
     microbatching,
+    checkpoint,
 ):
     config = model_configs[model]
@@ -559,6 +561,7 @@ def test_sanity_layernorm_mlp(
         normalization=normalization,
         params_dtype=dtype,
         device="cuda",
+        checkpoint=checkpoint,
     )
     _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
...
@@ -56,6 +56,8 @@ from ..distributed import (
     use_reentrant_activation_recompute,
     in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
+    _get_cuda_rng_state,
+    _set_cuda_rng_state,
 )
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
@@ -165,7 +167,7 @@ class _LayerNormMLP(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
+    def _forward(
         ctx,
         inp: torch.Tensor,
         ln_weight: torch.Tensor,
@@ -226,9 +228,103 @@ class _LayerNormMLP(torch.autograd.Function):
             module,
             skip_fp8_weight_update,
             symmetric_ar_type,
+            checkpoint,
             debug,
+            recompute_for_bwd,
         ) = non_tensor_args

+        # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take
+        if is_grad_enabled and not recompute_for_bwd:
+            ctx.checkpoint = checkpoint
+            if checkpoint:
+                # save the state of autocast and quantizers for recomputation
+                ctx.autocast_state = (
+                    FP8GlobalStateManager.get_autocast_state()
+                )  # to restore autocast state during recomputation
+                if (
+                    fp8
+                    and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__
+                    == "DelayedScaling"
+                ):  # only applicable for delayed scaling
+                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(
+                        module.fp8_meta
+                    )  # to restore quantizers during recomputation
+                # save the rng states
+                ctx.cpu_rng_state = torch.get_rng_state()
+                ctx.cuda_rng_state = _get_cuda_rng_state()
+
+        # whether to save activations regularly, or save inputs for recomputation in bwd
+        save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
+        # whether we are in the forward stage, or recomputing in the bwd stage (false if not checkpointing)
+        is_recomputation = checkpoint and is_grad_enabled and recompute_for_bwd

+        # save the initial state for recomputation by bwd
+        if save_for_checkpoint:
+            # save tensors
+            tensors_to_save, tensor_objects = prepare_for_saving(
+                inp,
+                ln_weight,
+                ln_bias,
+                fc1_weight,
+                fc1_bias,
+                fc2_weight,
+                fc2_bias,
+            )
+            ctx.save_for_backward(*tensors_to_save)
+            ctx.tensor_objects = tensor_objects
+            ctx.other_args = {
+                "eps": eps,
+                "is_first_microbatch": is_first_microbatch,
+                "fp8": fp8,
+                "fp8_calibration": fp8_calibration,
+                "wgrad_store": wgrad_store,
+                "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
+                "fc1_input_quantizer": fc1_input_quantizer,
+                "fc1_weight_quantizer": fc1_weight_quantizer,
+                "fc1_output_quantizer": fc1_output_quantizer,
+                "fc1_grad_input_quantizer": fc1_grad_input_quantizer,
+                "fc1_grad_weight_quantizer": fc1_grad_weight_quantizer,
+                "fc1_grad_output_quantizer": fc1_grad_output_quantizer,
+                "fc2_input_quantizer": fc2_input_quantizer,
+                "fc2_weight_quantizer": fc2_weight_quantizer,
+                "fc2_output_quantizer": fc2_output_quantizer,
+                "fc2_grad_input_quantizer": fc2_grad_input_quantizer,
+                "fc2_grad_weight_quantizer": fc2_grad_weight_quantizer,
+                "fc2_grad_output_quantizer": fc2_grad_output_quantizer,
+                "cpu_offloading": cpu_offloading,
+                "tp_group": tp_group,
+                "tp_size": tp_size,
+                "sequence_parallel": sequence_parallel,
+                "tensor_parallel": tensor_parallel,
+                "activation_dtype": activation_dtype,
+                "return_layernorm_output": return_layernorm_output,
+                "return_layernorm_output_gathered": return_layernorm_output_gathered,
+                "bias_gelu_fusion": bias_gelu_fusion,
+                "set_parallel_mode": set_parallel_mode,
+                "is_grad_enabled": is_grad_enabled,
+                "fwd_ln_sm_margin": fwd_ln_sm_margin,
+                "bwd_ln_sm_margin": bwd_ln_sm_margin,
+                "zero_centered_gamma": zero_centered_gamma,
+                "activation": activation,
+                "activation_params": activation_params,
+                "normalization": normalization,
+                "ub_overlap_ag": ub_overlap_ag,
+                "ub_overlap_rs": ub_overlap_rs,
+                "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
+                "ub_bulk_wgrad": ub_bulk_wgrad,
+                "ub_bulk_dgrad": ub_bulk_dgrad,
+                "gemm_gelu_fusion": gemm_gelu_fusion,
+                "fsdp_group": fsdp_group,
+                "module": module,
+                "skip_fp8_weight_update": skip_fp8_weight_update,
+                "symmetric_ar_type": symmetric_ar_type,
+                "checkpoint": checkpoint,
+                "debug": debug,
+                "recompute_for_bwd": True,  # set this to True for the recomputation phase
+            }

         # Make sure input dimensions are compatible
         in_features, inp_shape = ln_weight.numel(), inp.shape
         assert inp_shape[-1] == in_features, "GEMM not possible"
@@ -250,7 +346,14 @@ class _LayerNormMLP(torch.autograd.Function):
             start_offload(inputmat)

         tp_world_size = get_distributed_world_size(tp_group)
-        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+
+        # bwd needs the fc1 input when fc1 needs grad and either
+        # 1) grad is enabled without checkpointing, or
+        # 2) we are doing the checkpointed recomputation
+        backwards_needs_fc1_input = fc1_weight.requires_grad and (
+            (is_grad_enabled and not checkpoint) or is_recomputation
+        )
+
         device = inp.device

         # Configure Userbuffers communication (comm+GEMM overlap)
@@ -308,7 +411,9 @@ class _LayerNormMLP(torch.autograd.Function):
             zero_centered_gamma,
         )
         ln_out_return = None
-        if return_layernorm_output or return_layernorm_output_gathered:
+
+        # do not return the layernorm output while recomputing; it was already returned in fwd
+        if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation:
             ln_out_return = ln_out

         # Prepare GEMM input
@@ -316,7 +421,9 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_out_total = None
         ub_obj_lnout = None
         if sequence_parallel:
-            if return_layernorm_output_gathered:
+            # no need to gather the ln output for return during recomputation
+            if return_layernorm_output_gathered and not is_recomputation:
                 # Perform all-gather in high precision if gathered
                 # norm output will be returned
                 ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
@@ -459,7 +566,12 @@ class _LayerNormMLP(torch.autograd.Function):
         # ------------------------------------------------------

         # Deallocate FC1 GEMM input tensor if no longer needed
-        if not is_grad_enabled and (ln_out_total is not ln_out_return):
+        # clear ln_out_total only if
+        # 1) checkpointing and not recomputing (the forward stage, not the bwd recompute stage), or
+        # 2) not checkpointing and grad is disabled
+        if ((checkpoint and not is_recomputation) or not is_grad_enabled) and (
+            ln_out_total is not ln_out_return
+        ):
             clear_tensor_data(ln_out_total)

         # ACTIVATION - sometimes activation is fused with the GEMM above.
@@ -497,89 +609,88 @@ class _LayerNormMLP(torch.autograd.Function):
         else:
             act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)

-        if not is_grad_enabled:
-            clear_tensor_data(fc1_out)
-
         if not fp8 and fp8_calibration:
             if fc2_input_quantizer is not None:
                 fc2_input_quantizer.calibrate(act_out)
-            if fc2_weight_quantizer is not None:
-                fc2_weight_quantizer.calibrate(fc2_weight)
-
-        # Configure Userbuffers reduce-scatter if needed
-        ub_obj_fc2out = None
-        reduce_scatter_out = None
-        if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
-            dim_size = list(act_out.size())
-            dim_size[0] //= tp_world_size
-            dim_size[-1] = fc2_weight.size(0)
-            reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-
-        # ------------------------------------------------------
-        # FC2 GEMM
-        # ------------------------------------------------------
-        gemm_out, *_, reduce_scatter_out = general_gemm(
-            fc2_weight_final,
-            act_out,
-            out_dtype=activation_dtype,
-            bias=fc2_bias,
-            quantization_params=fc2_output_quantizer,
-            use_split_accumulator=use_split_accumulator,
-            ub=ub_obj_fc2out,
-            ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
-            extra_output=reduce_scatter_out,
-        )
-        # ------------------------------------------------------
-        # Finished FC2 GEMM...
-        # ------------------------------------------------------
-
-        # Deallocate tensors if no longer needed
-        if not is_grad_enabled:
-            clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-
-        # Prepare output tensor
-        # Note: Perform tensor-parallel communication if needed
-        fc2_out = None
-        if ub_overlap_rs:
-            fc2_out = reduce_scatter_out
-        elif set_parallel_mode and sequence_parallel:
-            fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
-        elif set_parallel_mode and tensor_parallel:
-            if symmetric_ar_type is not None:
-                fc2_out, _ = symmetric_all_reduce(
-                    gemm_out, tp_group, all_reduce_type=symmetric_ar_type
-                )
-            else:
-                fc2_out, _ = allreduce(gemm_out, tp_group)
-        else:
-            fc2_out = gemm_out
-        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
-
-        # Cache state for backward pass
-        if is_grad_enabled:
-            if cpu_offloading:
-                mark_activation_offload(
-                    inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
-                )
-
-            # Scatter intermediate/activation tensors saved for the backward pass
-            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
-            # shards/unshards the base weights so we don't do it ourselves
-            ctx.fsdp_group = fsdp_group
-            ctx.fsdp_shapes = _fsdp_scatter_tensors(
-                fsdp_group,
-                mu,
-                rsigma,
-                ln_out,
-                fc1_out_without_bias if bias_gelu_fusion else fc1_out,
-                act_out,
-                fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None,
-                fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
-            )
-
+        # we want to skip the fc2 computation if we are checkpointing and recomputing;
+        # otherwise we compute fc2
+        if not (is_recomputation and checkpoint):
+            # if we get to this point, we know this is not bwd recomputation,
+            # so we must be in the fwd
+            # now is_grad_enabled can be true or false:
+            # if false, we can safely delete;
+            # if true, we can only delete when checkpoint is true, since we will recompute anyway;
+            # otherwise, checkpoint is false, so we can't delete
+            if (
+                checkpoint or not is_grad_enabled
+            ):  # we can safely get rid of these if this is the case
+                clear_tensor_data(fc1_out)
+
+            if not fp8 and fp8_calibration:
+                if fc2_weight_quantizer is not None:
+                    fc2_weight_quantizer.calibrate(fc2_weight)
+
+            # Configure Userbuffers reduce-scatter if needed
+            ub_obj_fc2out = None
+            reduce_scatter_out = None
+            if ub_overlap_rs:
+                ub_obj_fc2out = get_ub("fc2_fprop", fp8)
+                dim_size = list(act_out.size())
+                dim_size[0] //= tp_world_size
+                dim_size[-1] = fc2_weight.size(0)
+                reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
+
+            # ------------------------------------------------------
+            # FC2 GEMM
+            # ------------------------------------------------------
+            gemm_out, *_, reduce_scatter_out = general_gemm(
+                fc2_weight_final,
+                act_out,
+                out_dtype=activation_dtype,
+                bias=fc2_bias,
+                quantization_params=fc2_output_quantizer,
+                use_split_accumulator=use_split_accumulator,
+                ub=ub_obj_fc2out,
+                ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
+                extra_output=reduce_scatter_out,
+            )
+            # ------------------------------------------------------
+            # Finished FC2 GEMM...
+            # ------------------------------------------------------
+
+            # Deallocate tensors if no longer needed; again, can safely deallocate
+            if checkpoint or not is_grad_enabled:  # same logic as last clear_tensor_data block
+                clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
+
+            # Prepare output tensor
+            # Note: Perform tensor-parallel communication if needed
+            fc2_out = None
+            if ub_overlap_rs:
+                fc2_out = reduce_scatter_out
+            elif set_parallel_mode and sequence_parallel:
+                fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
+            elif set_parallel_mode and tensor_parallel:
+                if symmetric_ar_type is not None:
+                    fc2_out, _ = symmetric_all_reduce(
+                        gemm_out, tp_group, all_reduce_type=symmetric_ar_type
+                    )
+                else:
+                    fc2_out, _ = allreduce(gemm_out, tp_group)
+            else:
+                fc2_out = gemm_out
+            fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
+
+        # now saving state for bwd:
+        # if we are checkpointing, this information will be saved in the bwd recomputation stage, so we can skip it in fwd;
+        # if we are not checkpointing, then we must save it whenever grad is enabled
+        if is_grad_enabled and not save_for_checkpoint:
             ctx.fc1_weight_quantizer = fc1_weight_quantizer
             ctx.fc2_weight_quantizer = fc2_weight_quantizer

             if not fc1_weight.requires_grad:
                 if not return_layernorm_output:
                     clear_tensor_data(ln_out)
@@ -588,34 +699,69 @@ class _LayerNormMLP(torch.autograd.Function):
                 clear_tensor_data(act_out)
                 act_out = None

-            if cpu_offloading:
-                mark_not_offload(
-                    ln_weight,
-                    ln_bias,
-                    fc1_weight_final,
-                    fc1_weight,
-                    fc1_bias,
-                    fc2_weight_final,
-                    fc2_weight,
-                    fc2_bias,
-                )
-            tensors_to_save, tensor_objects = prepare_for_saving(
-                inputmat,
-                ln_weight,
-                ln_out,
-                fc1_weight_final,
-                fc1_weight,
-                fc1_bias,
-                fc1_out,
-                fc1_out_without_bias,
-                act_out,
-                fc2_weight_final,
-                fc2_weight,
-                fc2_bias,
-                mu,
-                rsigma,
-            )
+            if not checkpoint:  # regular path, no selective activation checkpointing
+
+                if cpu_offloading:
+                    mark_activation_offload(
+                        inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
+                    )
+
+                # Scatter intermediate/activation tensors saved for the backward pass
+                # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
+                # shards/unshards the base weights so we don't do it ourselves
+                ctx.fsdp_group = fsdp_group
+                ctx.fsdp_shapes = (
+                    _fsdp_scatter_tensors(  # again, only relevant if we have activations to save
+                        fsdp_group,
+                        mu,
+                        rsigma,
+                        ln_out,
+                        fc1_out_without_bias if bias_gelu_fusion else fc1_out,
+                        act_out,
+                        (
+                            fc1_weight_final
+                            if fp8 and not isinstance(fc1_weight, Float8Tensor)
+                            else None
+                        ),
+                        (
+                            fc2_weight_final
+                            if fp8 and not isinstance(fc2_weight, Float8Tensor)
+                            else None
+                        ),
+                    )
+                )
+
+                if cpu_offloading:
+                    mark_not_offload(
+                        ln_weight,
+                        ln_bias,
+                        fc1_weight_final,
+                        fc1_weight,
+                        fc1_bias,
+                        fc2_weight_final,
+                        fc2_weight,
+                        fc2_bias,
+                    )
+                tensors_to_save, tensor_objects = prepare_for_saving(
+                    inputmat,
+                    ln_weight,
+                    ln_out,
+                    fc1_weight_final,
+                    fc1_weight,
+                    fc1_bias,
+                    fc1_out,
+                    fc1_out_without_bias,
+                    act_out,
+                    fc2_weight_final,
+                    fc2_weight,
+                    fc2_bias,
+                    mu,
+                    rsigma,
+                )
+                ctx.save_for_backward(*tensors_to_save)
+                ctx.tensor_objects = tensor_objects

             if fuse_wgrad_accumulation:
                 # This check is needed to ensure that main_grad is not created
@@ -633,9 +779,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
                 ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad

-            ctx.save_for_backward(*tensors_to_save)
-            ctx.tensor_objects = tensor_objects
-
             ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
             ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
             ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
@@ -690,11 +833,30 @@ class _LayerNormMLP(torch.autograd.Function):
         ):
             _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
-            if in_fp8_activation_recompute_phase():
+            if in_fp8_activation_recompute_phase() or is_recomputation:
                 FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

         ctx.wgrad_store = wgrad_store

+        if is_recomputation:  # return the recomputed tensors
+            return (
+                ctx,
+                inputmat,
+                ln_weight,
+                ln_out,
+                fc1_weight_final,
+                fc1_weight,
+                fc1_bias,
+                fc1_out,
+                fc1_out_without_bias,
+                act_out,
+                fc2_weight_final,
+                fc2_weight,
+                fc2_bias,
+                mu,
+                rsigma,
+            )
+
+        # we only get to this point if we are not recomputing for bwd, since that path returned above
         if return_layernorm_output:
             if return_layernorm_output_gathered:
                 shape = list(inp_shape)
@@ -703,14 +865,101 @@ class _LayerNormMLP(torch.autograd.Function):
             return fc2_out, ln_out_return.view(inp_shape)
         return fc2_out

+    @staticmethod
+    def forward(
+        ctx,
+        inp: torch.Tensor,
+        ln_weight: torch.Tensor,
+        ln_bias: torch.Tensor,
+        fc1_weight: torch.Tensor,
+        fc1_bias: torch.Tensor,
+        fc2_weight: torch.Tensor,
+        fc2_bias: torch.Tensor,
+        non_tensor_args: Tuple,
+    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
+        # pylint: disable=missing-function-docstring
+        # add recompute_for_bwd
+        non_tensor_args += (False,)
+        return _LayerNormMLP._forward(
+            ctx,
+            inp,
+            ln_weight,
+            ln_bias,
+            fc1_weight,
+            fc1_bias,
+            fc2_weight,
+            fc2_bias,
+            non_tensor_args,
+        )
+
+    @staticmethod
+    def _recompute(ctx):
+        # pylint: disable=missing-function-docstring
+        saved_tensors = ctx.saved_tensors
+        tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
+        # Delete the references to tensor objects once they've been consumed
+        # by the `restore_from_saved` method to construct back the actual tensors.
+        ctx.tensor_objects = None
+
+        if ctx.checkpoint:  # do recomputation from the original args
+            # backward is not in an autocast context, so we set the state here;
+            # we also have to set the quantizer states to what they were before
+            # the forward pass (only relevant for the DelayedScaling recipe)
+            final_autocast_state = (
+                FP8GlobalStateManager.get_autocast_state()
+            )  # get current autocast state
+            FP8GlobalStateManager.set_autocast_state(ctx.autocast_state)  # set old autocast state
+            if (
+                ctx.other_args["fp8"]
+                and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
+            ):  # only applicable for delayed scaling
+                FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(
+                    ctx.other_args["module"].fp8_meta
+                )  # set old quantizer state
+
+            # get current rng state
+            final_cpu_rng_state = torch.get_rng_state()
+            final_cuda_rng_state = _get_cuda_rng_state()
+            # set the rng state to what it was in fwd
+            torch.set_rng_state(ctx.cpu_rng_state)
+            _set_cuda_rng_state(ctx.cuda_rng_state)
+
+            out = _LayerNormMLP._forward(  # recompute
+                ctx,
+                *tensors,
+                tuple(ctx.other_args.values()),
+            )
+
+            FP8GlobalStateManager.set_autocast_state(final_autocast_state)  # restore autocast state
+            if (
+                ctx.other_args["fp8"]
+                and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
+            ):
+                FP8GlobalStateManager.restore_fp8_meta_tensors(
+                    ctx.other_args["module"].fp8_meta
+                )  # restore quantizers
+
+            # restore the rng state from after the fwd
+            torch.set_rng_state(final_cpu_rng_state)
+            _set_cuda_rng_state(final_cuda_rng_state)
+
+            return out
+
+        # load from saved (ctx is returned just because the other branch also returns it)
+        return tuple([ctx] + tensors)
+
     @staticmethod
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
         with get_nvtx_range_context("_LayerNormMLP_backward"):
-            saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
+                ctx,
                 inputmat,
                 ln_weight,
                 ln_out,
@@ -725,11 +974,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_bias,
                 mu,
                 rsigma,
-            ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
-
-            # Delete the references to tensor objects once they've been consumed
-            # by the `restore_from_saved` method to construct back the actual tensors.
-            ctx.tensor_objects = None
+            ) = _LayerNormMLP._recompute(ctx)

             # Since main_grad can be modified inplace, it should not be a part of saved_tensors
             fc1_weight_main_grad = (
@@ -1512,6 +1757,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
                       This can help in latency bound communication situations.
                       Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                       is used.
+    checkpoint: bool, default = False
+               whether to use selective activation checkpointing, where activations are not saved
+               for bwd and are instead recomputed (skipping FC2, which is not needed for backward).
+               Trades compute for memory. Not supported for the ONNX forward path.
     """
     def __init__(
@@ -1547,6 +1796,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         ub_bulk_wgrad: bool = False,
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
+        checkpoint: bool = False,
     ) -> None:
         super().__init__()
@@ -1567,6 +1817,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.set_parallel_mode = set_parallel_mode
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
+        self.checkpoint = checkpoint

         # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
         self.gemm_gelu_fusion = (
@@ -1896,6 +2147,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self,
             skip_fp8_weight_update,
             self.symmetric_ar_type,
+            self.checkpoint,
             debug,
         )
         out = fwd_fn(
...
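Stripped of FP8 state handling, Userbuffers, and tensor parallelism, the control flow added above reduces to a familiar pattern: save only the inputs plus RNG state in forward, replay the forward during backward, and skip whatever backward does not need (here, FC2). Below is a minimal, self-contained sketch of that pattern in plain PyTorch; a ReLU MLP stands in for LayerNormMLP, and all names are hypothetical, not code from this PR:

import torch


class CheckpointedMLP(torch.autograd.Function):
    """Save only the inputs in forward; rebuild the activations in backward."""

    @staticmethod
    def forward(ctx, inp, w1, w2):
        # Forward math for out = relu(inp @ w1.T) @ w2.T.
        hidden = torch.relu(inp @ w1.t())
        out = hidden @ w2.t()
        # Save only the inputs plus RNG state; `hidden` is dropped to save memory.
        ctx.save_for_backward(inp, w1, w2)
        ctx.cpu_rng_state = torch.get_rng_state()
        if torch.cuda.is_available():
            ctx.cuda_rng_state = torch.cuda.get_rng_state()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, w1, w2 = ctx.saved_tensors
        # Stash the current RNG state, then replay the forward under the saved
        # state so stochastic ops (e.g. dropout) would recompute identically.
        cpu_state = torch.get_rng_state()
        torch.set_rng_state(ctx.cpu_rng_state)
        if torch.cuda.is_available():
            cuda_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(ctx.cuda_rng_state)
        # Recompute only what backward needs; the FC2 output itself is skipped.
        pre_act = inp @ w1.t()
        hidden = torch.relu(pre_act)
        # Restore the RNG state of the surrounding program.
        torch.set_rng_state(cpu_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_state)
        # Manual gradients for the two GEMMs and the activation.
        grad_hidden = grad_out @ w2
        grad_pre = grad_hidden * (pre_act > 0)
        grad_inp = grad_pre @ w1
        grad_w1 = grad_pre.t() @ inp
        grad_w2 = grad_out.t() @ hidden
        return grad_inp, grad_w1, grad_w2


# Usage: identical numerics to the uncheckpointed module, lower forward memory.
inp = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(32, 16, requires_grad=True)
w2 = torch.randn(16, 32, requires_grad=True)
CheckpointedMLP.apply(inp, w1, w2).sum().backward()

The real implementation additionally snapshots and restores the FP8 autocast and DelayedScaling quantizer state around the replay, and routes the replay through _forward with recompute_for_bwd=True so the FC2 GEMM and output communication are skipped.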