Unverified commit 05bfa3f8 authored by Jaime, committed by GitHub

[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag (#2311)

* custom tests for selective activation checkpointing for layernorm mlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add selective layernorm mlp to te.pytorch
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* update test and fix SLNMLP bug
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* implement slnmlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix tests pointed out by greptile app bot, still pass
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* minor formatting change in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove duplicate import in test/pytorch/selective_layernorm_mlp/test_recipe.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* clean up tests, remove unused imports
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove unused paths in test_deffered_init
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix issue with zero_centered_gamma in test_numerics reference implementation
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* clean up tests
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make comparison.py more extensive, cleaner output
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix typo flagged by greptile bot in compare.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make selective activation checkpointing optional in slnmlp via checkpoint flag
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add comments to clarify logic
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* add checkpoint param to pytests, change compare.py to compare checkpoint=False vs checkpoint=True, skip cuda graph tests for checkpoint=True
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* refactor tests to call modified LayerNormMLP
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* refactor to implement selective activation checkpointing directly into LayerNormMLP, also fix bug to reach cleanup logic in fwd
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix skip explanation for cuda_graphs.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* make _recompute deal with lists instead of tuples
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix MOST cuda graph failures by initializing identical quantizers during fwd. Float8CurrentScaling with bf16 and fp16 still fail with checkpointing
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix cuda graphs issue, all tests pass now
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix small logic bugs, clean up
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* integrate tests into main testing scripts
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* incorporate rng state tracking in checkpointing
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up tests
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix return type mismatches
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* remove checkpoint test from test_recipe, add separate test in test_numerics
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor typo fix
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clear up assertions in tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add license and copyright info
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix lint issues in layernorm_mlp
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* fix cpu_offload_v1 error
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* possibly fix recomputation bug under cuda graphs
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip cuda graphs test for SLNMLP with SM>=10.0 and using delayed scaling
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo for setting IS_FIRST_FP8_MODULE
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>

---------
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
Signed-off-by: Jaime <102792198+jaimec00@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 30c0120b
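
The diffs below wire a new `checkpoint` flag through the LayerNormMLP test suite. For orientation, a minimal usage sketch of the flag (sizes are illustrative; only the `checkpoint` argument itself is taken from this PR):

```python
import torch
from transformer_engine.pytorch import LayerNormMLP

# checkpoint=True enables selective activation checkpointing: the MLP's large
# intermediate activations are dropped after the forward pass and recomputed
# during backward, trading extra backward compute for forward-memory savings.
mlp = LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, checkpoint=True).cuda()
x = torch.randn(2048, 1024, device="cuda", requires_grad=True)
y = mlp(x)          # forward keeps only what recomputation needs
y.sum().backward()  # fc1/activation outputs are recomputed here
```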
......@@ -1030,6 +1030,7 @@ def test_layernorm_mlp():
{"return_bias": True},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
{"checkpoint": True},
]
for kwargs in kwargs_list:
......
......@@ -13,7 +13,7 @@ import transformer_engine.pytorch as te
"""
Distributed numerics tests
These tests test the numerical corectness of the TransformerEngine layers.
These tests test the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and fp8 precision.
One test consists of running multiple configurations from file run_numerics.py
Such design is due to the fact the initialization of one test is long
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
from transformer_engine.pytorch import LayerNormMLP
import pytest

torch.manual_seed(1234)
device = torch.device("cuda")


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


class ModelConfig:
    def __init__(
        self,
        hidden_size: int = 128,
        ffn_hidden_size: int = 512,
        layers: int = 1,
    ):
        self._hidden_size = hidden_size
        self._ffn_hidden_size = ffn_hidden_size
        self._layers = layers

    def build(self):
        ln_list, sln_list = [], []
        for _ in range(self._layers):
            ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device)
            sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device)
            with torch.no_grad():
                sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone())
                sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone())
                sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone())
                sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone())
                sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone())
                sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone())
            ln_list.append(ln)
            sln_list.append(sln)
        ln_model = _Sequential(*ln_list)
        sln_model = _Sequential(*sln_list)
        return ln_model, sln_model


config = {
    "small": ModelConfig(128, 512, 12),
    "medium": ModelConfig(512, 2048, 12),
    "large": ModelConfig(1024, 4096, 12),
    "huge": ModelConfig(2048, 8192, 12),
}
seq_sizes = [2**7, 2**10, 2**14, 2**16]


def _warmup(model, tensor):
    for _ in range(3):
        model(tensor).sum().backward()


def _run_fwd(model, tensor):
    torch.cuda.reset_peak_memory_stats(device)
    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True
    )
    torch.cuda.synchronize()
    start_mem = torch.cuda.memory_allocated(device)
    start_time.record()
    out = model(tensor)
    end_time.record()
    end_time.synchronize()
    elapsed = start_time.elapsed_time(end_time)
    peak_mem = torch.cuda.max_memory_allocated(device)
    mem = float(peak_mem - start_mem)
    return out, elapsed, mem


def _run_bwd(model, out):
    model.zero_grad(set_to_none=False)
    loss = out.sum()
    torch.cuda.reset_peak_memory_stats(device)
    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True
    )
    torch.cuda.synchronize()
    start_mem = torch.cuda.memory_allocated(device)
    start_time.record()
    loss.backward()
    end_time.record()
    end_time.synchronize()
    elapsed = start_time.elapsed_time(end_time)
    peak_mem = torch.cuda.max_memory_allocated(device)
    mem = float(peak_mem - start_mem)
    param_grads = _collect_param_grads(model)
    return param_grads, elapsed, mem


def _max_diff(ref, other):
    """Return max absolute difference between two tensors or collections."""
    if ref is None or other is None:
        return 0.0
    if isinstance(ref, (list, tuple)):
        diffs = [_max_diff(r, o) for r, o in zip(ref, other)]
        return max(diffs) if diffs else 0.0
    return torch.max(torch.abs(ref.detach() - other.detach())).item()


def _collect_param_grads(model):
    grads = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        key = _param_key(name)
        if key is not None:
            grads[key] = param.grad.detach().clone()
    return grads


def _param_key(name):
    return name.split(".")[-1]


@pytest.mark.parametrize("size", config.keys())
@pytest.mark.parametrize("seq_size", seq_sizes)
def test_selective_activation_checkpoint(size, seq_size):
    ln_model, sln_model = config[size].build()
    data = torch.randn((seq_size, config[size]._hidden_size), device=device)

    _warmup(ln_model, data)
    ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data)
    ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out)

    _warmup(sln_model, data)
    sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data)
    sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out)

    assert ln_fwd_mem > 6 * sln_fwd_mem, (
        "selective activation checkpointing does not reduce forward memory by 6X, only by"
        f" {ln_fwd_mem/sln_fwd_mem}!"
    )
    assert ln_bwd_time < sln_bwd_time, (
        "selective activation checkpointing backward pass is NOT slower than native!"
        f" got Native LayerNormMLP Backward Time: {ln_bwd_time} ms and Selective Activation"
        f" Checkpointed LayerNormMLP Backward Time: {sln_bwd_time} ms"
    )

    diff = _max_diff(ln_fwd_out, sln_fwd_out)
    assert diff == 0.0, f"outputs are not equal! maximum difference {diff}"

    for key in [
        "layer_norm_weight",
        "layer_norm_bias",
        "fc1_weight",
        "fc1_bias",
        "fc2_weight",
        "fc2_bias",
    ]:
        diff = _max_diff(ln_grads[key], sln_grads[key])
        assert diff == 0.0, f"gradients for {key} are not equal! maximum difference: {diff}"
......@@ -190,7 +190,8 @@ _test_cuda_graphs_modules: List[str] = [
    # creating TMA descriptor for MXFP8 quantization.
    "linear",
    "transformer",
    "layernorm_mlp",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "mha",
    "linear_op",
......@@ -232,12 +233,23 @@ def _test_cuda_graphs(
            )
            for _ in range(num_layers)
        ]
    elif module == "layernorm_mlp":
    elif module == "layernorm_mlp_nocheckpoint":
        modules = [
            LayerNormMLP(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
                checkpoint=False,
            )
            for _ in range(num_layers)
        ]
    elif module == "layernorm_mlp_checkpoint":
        modules = [
            LayerNormMLP(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
                checkpoint=True,
            )
            for _ in range(num_layers)
        ]
......@@ -376,6 +388,17 @@ def test_make_graphed_callables(
        )
    if fp8_params:
        pytest.skip("NVFP4 params not supported")
    if (
        fp8
        and fp8_recipe.delayed()
        and torch.cuda.get_device_capability() >= (10, 0)
        and module == "layernorm_mlp_checkpoint"
    ):
        pytest.skip(
            "CUDA graphs not supported for LayerNormMLP "
            "with checkpoint=True, SM>=10, "
            "and DelayedScaling recipe"
        )

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
......@@ -402,7 +425,8 @@ def test_make_graphed_callables(
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "linear",
    "mha",
......
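
The suite drives graph capture through its own test helper; as a rough standalone analogue, stock PyTorch capture might look like the sketch below. Whether plain `torch.cuda.make_graphed_callables` handles TE modules under every recipe is an assumption (the skip added above shows at least one combination that does not work):

```python
import torch
from transformer_engine.pytorch import LayerNormMLP

# Capture forward and backward of a checkpoint-enabled LayerNormMLP.
module = LayerNormMLP(1024, 4096, checkpoint=True).cuda()
sample = torch.randn(8, 1024, device="cuda", requires_grad=True)
graphed = torch.cuda.make_graphed_callables(module, (sample,))

out = graphed(torch.randn(8, 1024, device="cuda", requires_grad=True))
out.sum().backward()  # replays the captured backward graph
```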
......@@ -185,7 +185,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsuppored dtype ({dtype})")
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
......@@ -1363,7 +1363,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

    # Shoule be bit-wise match
    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
......@@ -1696,7 +1696,11 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    dtype, bs, model, bias, fuse_wgrad_accumulation
    dtype,
    bs,
    model,
    bias,
    fuse_wgrad_accumulation,
):
    config = model_configs[model]
......@@ -1747,6 +1751,58 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy_checkpoint(
    dtype,
    bs,
    model,
    bias,
):
    config = model_configs[model]

    ln_mlp = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        checkpoint=True,
    ).eval()

    ln_mlp_ref = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        checkpoint=False,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
    te_outputs_ref = _test_granular_accuracy(
        ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_grouped_linear_accuracy(
    block,
    num_gemms,
......
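
The bit-wise match asserted above is only possible if the recompute pass replays RNG-dependent ops exactly (the "incorporate rng state tracking" commit). A minimal sketch of the standard fork-and-restore pattern; the helper name and call shape are hypothetical, not TE's internal API:

```python
import torch


def recompute_with_saved_rng(fn, cpu_state, cuda_state, *args):
    """Hypothetical helper: rerun fn under the RNG state captured in forward,
    so stochastic ops (e.g. dropout) produce bit-identical results."""
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        torch.set_rng_state(cpu_state)
        torch.cuda.set_rng_state(cuda_state)
        return fn(*args)


# In forward, capture the state before running the stochastic region:
cpu_state = torch.get_rng_state()
cuda_state = torch.cuda.get_rng_state()
```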
......@@ -29,7 +29,6 @@ from transformer_engine.pytorch.quantization import (
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
......
......@@ -525,6 +525,7 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
dtype,
fp8_recipe,
......@@ -535,6 +536,7 @@ def test_sanity_layernorm_mlp(
activation,
normalization,
microbatching,
checkpoint,
):
config = model_configs[model]
......@@ -559,6 +561,7 @@ def test_sanity_layernorm_mlp(
normalization=normalization,
params_dtype=dtype,
device="cuda",
checkpoint=checkpoint,
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
......