Commit c1a1c04e authored by wenjh's avatar wenjh
Browse files

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
...@@ -45,7 +45,8 @@ from transformer_engine.pytorch.utils import ( ...@@ -45,7 +45,8 @@ from transformer_engine.pytorch.utils import (
) )
from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import ( from transformer_engine.pytorch.quantized_tensor import (
Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
...@@ -60,8 +61,16 @@ from utils import ( ...@@ -60,8 +61,16 @@ from utils import (
get_available_attention_backends, get_available_attention_backends,
) )
# Check if hardware supports FP8 # Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
fp8_attn_available = False
reason_for_no_fp8_attn = (
"FP8 attention is not supported for compute capability ="
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)
# Reset RNG seed and states # Reset RNG seed and states
seed = 1234 seed = 1234
...@@ -130,6 +139,11 @@ def test_dot_product_attention( ...@@ -130,6 +139,11 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa: if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2] config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
)
# Get backends # Get backends
is_training = True is_training = True
...@@ -171,7 +185,7 @@ def test_dot_product_attention( ...@@ -171,7 +185,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"UnfusedDotProductAttention", "UnfusedDotProductAttention",
...@@ -185,7 +199,7 @@ def test_dot_product_attention( ...@@ -185,7 +199,7 @@ def test_dot_product_attention(
# FusedAttention backend # FusedAttention backend
if fused_attn_supported: if fused_attn_supported:
if len(fused_attn_backends) == 1: if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -197,7 +211,7 @@ def test_dot_product_attention( ...@@ -197,7 +211,7 @@ def test_dot_product_attention(
) )
if len(fused_attn_backends) == 2: if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -208,7 +222,7 @@ def test_dot_product_attention( ...@@ -208,7 +222,7 @@ def test_dot_product_attention(
is_training, is_training,
) )
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -221,7 +235,7 @@ def test_dot_product_attention( ...@@ -221,7 +235,7 @@ def test_dot_product_attention(
# FlashAttention backend # FlashAttention backend
if flash_attn_supported: if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FlashAttention", "FlashAttention",
...@@ -242,6 +256,8 @@ def test_dot_product_attention( ...@@ -242,6 +256,8 @@ def test_dot_product_attention(
if unfused_attn_supported and fused_attn_supported: if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn") logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_logit:
torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
for i, _ in enumerate(unfused_attn_bwd): for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported: if fused_attn_supported and flash_attn_supported:
...@@ -265,6 +281,33 @@ def test_dpa_checkpoint(dtype, model_configs, model): ...@@ -265,6 +281,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
# Model configs for the return_max_logit tests. Each entry varies one attention
# feature (different max_seqlen_kv — presumably cross-attention; causal/padding
# masks; post-scale bias; sliding window; large head dims) — TODO confirm
# ModelConfig keyword semantics against its definition.
model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "max_logit_4": ModelConfig(
        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
    ),
    "max_logit_5": ModelConfig(
        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention with return_max_logit=True.

    Enables max-logit return on the model config and delegates to
    test_dot_product_attention, which compares the max_logit values returned
    by the fused and unfused backends.
    """
    config = model_configs[model]
    config.return_max_logit = True
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_softmax = { model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk) # test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
...@@ -962,6 +1005,8 @@ def _run_dot_product_attention( ...@@ -962,6 +1005,8 @@ def _run_dot_product_attention(
layout = layout.replace("d", "dqk") layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig = tensor tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...@@ -1071,6 +1116,7 @@ def _run_dot_product_attention( ...@@ -1071,6 +1116,7 @@ def _run_dot_product_attention(
layer_number=1, layer_number=1,
attention_type=config.attn_type, attention_type=config.attn_type,
softmax_type=config.softmax_type, softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
if not is_training: if not is_training:
block = block.eval() block = block.eval()
...@@ -1108,16 +1154,21 @@ def _run_dot_product_attention( ...@@ -1108,16 +1154,21 @@ def _run_dot_product_attention(
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
fast_zero_fill=True, fast_zero_fill=True,
) )
max_logit = None
if config.return_max_logit:
out, max_logit = out
if is_training: if is_training:
out.backward(d_out) out.backward(d_out)
d_softmax_offset = None d_softmax_offset = None
if is_training and config.softmax_type != "vanilla": if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training: if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset) return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else: else:
return out, (None, None, None, d_softmax_offset) return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention": if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...@@ -1146,14 +1197,18 @@ def _run_dot_product_attention( ...@@ -1146,14 +1197,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
) )
if is_training: if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) return (
out_orig,
max_logit,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else: else:
return out_orig, (None, None, None, d_softmax_offset) return out_orig, max_logit, (None, None, None, d_softmax_offset)
else: else:
if is_training: if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset) return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else: else:
return out, (None, None, None, d_softmax_offset) return out, max_logit, (None, None, None, d_softmax_offset)
model_configs_te_layer = { model_configs_te_layer = {
...@@ -1527,8 +1582,7 @@ model_configs_fp8_extra_state = { ...@@ -1527,8 +1582,7 @@ model_configs_fp8_extra_state = {
} }
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
...@@ -1690,8 +1744,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] ...@@ -1690,8 +1744,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
...@@ -1927,8 +1980,7 @@ def _run_mha_fp8_vs_f16( ...@@ -1927,8 +1980,7 @@ def _run_mha_fp8_vs_f16(
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
...@@ -2256,8 +2308,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"] ...@@ -2256,8 +2308,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
), ),
reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
) )
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model): def test_custom_mha_fp8_vs_f16(dtype, model):
......
...@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = { model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk) # test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA
"cp_1_2": ModelConfig( "cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA ), # MHA
...@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"] ...@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"] qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential: if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"] dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"] qkv_formats = ["sbhd", "thd"]
......
...@@ -685,11 +685,12 @@ if __name__ == "__main__": ...@@ -685,11 +685,12 @@ if __name__ == "__main__":
pass pass
else: else:
test_log_expert_parallel() test_log_expert_parallel()
if fp8_available:
for parallel_mode in ["column", "row"]: for parallel_mode in ["column", "row"]:
for gather_weight in [True, False]: for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight) test_log_distributed(parallel_mode, gather_weight)
if fp8_available:
for parallel_mode in ["row", "column"]: for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode) test_disable_fp8_layer(parallel_mode)
......
# Config for test_layer_switches_to_nondebug_mode: attaches TestDummyFeature to
# every layer. With inspect_only_once: True the feature is presumably probed via
# inspect_tensor_enabled once per tensor, after which the layer can drop back to
# non-debug mode — confirm against TestDummyFeature's implementation.
test_switch_to_nondebug_mode:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    TestDummyFeature:
      enabled: True
      inspect_only_once: True
      # One entry per tensor the debug hooks can observe.
      tensors: [weight, activation, gradient, output, wgrad, dgrad]
      gemms: [wgrad, dgrad, fprop]
...@@ -18,7 +18,11 @@ from transformer_engine.pytorch import ( ...@@ -18,7 +18,11 @@ from transformer_engine.pytorch import (
) )
from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_computation import (
compute_max_blockwise_dynamic_range,
BlockwiseDynamicRangeStat,
)
import math
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
...@@ -154,7 +158,7 @@ fp8_recipes = [ ...@@ -154,7 +158,7 @@ fp8_recipes = [
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs): def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
...@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs): ...@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs):
assert overflows == pytest.approx(expected.cpu(), abs=1e-4) assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
LOG_HIGH_PRECISION_CONFIG = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
stats:
- dynamic_range
- max_blockwise_dynamic_range:
block_size: 4
dims: 1
- max_blockwise_dynamic_range:
block_size: 4
dims: 2
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
def test_log_stats_numerics(feature_dirs, tensor_name):
    """Check correctness of dynamic range and max blockwise dynamic range stats.

    Tests different tensor types:
    - activation/weight: use both orientations (rowwise + columnwise), takes max
    - gradient/dgrad: use single orientation (rowwise only)
    """
    log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
        # There is 1024 x 1024 tensor with very small epsilon values in almost all elements,
        # one row of large value A and three rows of large value B.
        epsilon = 1e-10
        A = 1000
        B = 50
        tensor = torch.zeros(1024, 1024).cuda() + epsilon
        tensor[0, :] = A
        tensor[1:4, :] = B

        # Feed the tensor straight to the inspect API — no model forward needed.
        debug_api.transformer_engine.inspect_tensor(
            layer_name="layer_name",
            tensor_name=tensor_name,
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=None,
            rowwise_quantized_tensor=None,
            columnwise_quantized_tensor=None,
        )
        debug_api.step()
        output = read_log(log_dir)

        # activation/weight stats are reduced over both orientations; the logged
        # stat name then carries a "_max_over_orientations" suffix.
        max_over_orientations = tensor_name in ["activation", "weight"]
        max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""

        # Track which stats were found to ensure all are present
        found_dims_1 = False
        found_dims_2 = False
        found_dynamic_range = False

        for line in output.splitlines():
            if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
                max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
                if max_over_orientations:
                    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
                    expected = math.log2(A) - math.log2(B)
                else:
                    # Rowwise blocks have uniform values -> dynamic_range = 0
                    expected = 0
                assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
                    expected, abs=1e-4
                )
                found_dims_1 = True
            elif (
                f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
            ):
                max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
                # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
                expected = math.log2(A) - math.log2(B)
                assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
                    expected, abs=1e-4
                )
                found_dims_2 = True
            elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
                dynamic_range = float(line.split("value=")[1])
                # Whole-tensor dynamic range spans the extreme values A and epsilon.
                expected = math.log2(A) - math.log2(epsilon)
                assert dynamic_range == pytest.approx(expected, abs=1e-4)
                found_dynamic_range = True

        # Ensure all expected stats were found in the output
        assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
        assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
        assert found_dynamic_range, "dynamic_range not found in output"
@pytest.mark.parametrize("layer", ["linear", "transformer"]) @pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
if not fp8_available: if not fp8_available:
...@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): ...@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
debug_api.end_debug() debug_api.end_debug()
TEDebugState._reset() TEDebugState._reset()
def test_compute_max_blockwise_dynamic_range_direct():
    """Direct unit test for compute_max_blockwise_dynamic_range function.

    Exercises several block sizes, block dimensionalities, and orientation
    settings, checking each result against a hand-computed dynamic range.
    """
    # Layout: row 0 holds A, rows 1..3 hold B, every other element is eps.
    eps = 0.01
    A = 1000.0
    B = 50.0
    t = torch.full((1024, 1024), eps).cuda()
    t[0, :] = A
    t[1:4, :] = B

    # Case 1: rowwise-only 1D blocks. Each block lies within a single row, so
    # its values are uniform and the dynamic range must be 0.
    cfg = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
    got = compute_max_blockwise_dynamic_range(t, cfg)
    assert got.item() == pytest.approx(
        0.0, abs=1e-4
    ), "Rowwise 1D blocks with uniform values should have dynamic_range=0"

    # Case 2: max over rowwise and columnwise orientations. Columnwise blocks
    # mix [A, B, B, B], so the range is log2(A/B).
    cfg = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
    got = compute_max_blockwise_dynamic_range(t, cfg)
    want = math.log2(A) - math.log2(B)
    assert got.item() == pytest.approx(want, abs=1e-4), (
        f"Max over orientations should capture columnwise dynamic_range, expected {want}, got"
        f" {got.item()}"
    )

    # Case 3: 2D 4x4 tiles span several rows and therefore always mix values.
    cfg = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
    got = compute_max_blockwise_dynamic_range(t, cfg)
    want = math.log2(A) - math.log2(B)
    assert got.item() == pytest.approx(want, abs=1e-4), (
        f"2D blocks should capture mixed values from different rows, expected {want}, got"
        f" {got.item()}"
    )

    # Case 4: with block_size=8 a columnwise block holds
    # [A, B, B, B, eps, eps, eps, eps], so its minimum is eps rather than B.
    cfg = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
    got = compute_max_blockwise_dynamic_range(t, cfg)
    want = math.log2(A) - math.log2(eps)  # min is epsilon, not B
    assert got.item() == pytest.approx(
        want, abs=1e-4
    ), f"Block size 8 should work correctly, expected {want}, got {got.item()}"

    # Case 5: a perfectly uniform tensor has zero dynamic range everywhere.
    flat = torch.full((64, 64), 42.0).cuda()
    cfg = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
    got = compute_max_blockwise_dynamic_range(flat, cfg)
    assert got.item() == pytest.approx(
        0.0, abs=1e-4
    ), "Uniform tensor should have dynamic_range=0"

    # Case 6: 3D-input flattening. A 4x4 tensor built from distinct 2x2 tiles
    # must give the same dims=2 result after being reshaped to [2, 2, 4] — the
    # dominant tile yields log2(1000/100).
    grid = torch.tensor(
        [
            [1.0, 1.0, 10.0, 10.0],
            [1.0, 1.0, 10.0, 10.0],
            [100.0, 100.0, 1000.0, 1000.0],
            [100.0, 100.0, 1000.0, 1000.0],
        ]
    ).cuda()
    cfg = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
    res2d = compute_max_blockwise_dynamic_range(grid, cfg)
    res3d = compute_max_blockwise_dynamic_range(grid.reshape(2, 2, 4), cfg)
    assert res2d.item() == pytest.approx(res3d.item(), abs=1e-6), (
        "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got"
        f" 2D={res2d.item()}, 3D={res3d.item()}"
    )

    print("All direct tests for compute_max_blockwise_dynamic_range passed!")
...@@ -6,71 +6,70 @@ ...@@ -6,71 +6,70 @@
import pytest import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs): @pytest.mark.parametrize("use_microbatching", [False, True])
debug_api.end_debug() def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching):
TEDebugState._reset() """
if debug_tools_initialized: Test that layers switch to non-debug mode when no features are active.
# This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS.
# So after 1 warm-up iteration, this layers should work in non-debug mode.
debug_api.initialize(
config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
)
try: Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None).
if layer == "linear": The TE should:
model = torch.nn.Sequential( 1. Call inspect_tensor_enabled to check if feature is needed
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") 2. Never call inspect_tensor
).cuda() 3. Allow layers to switch to non-debug mode for optimal performance,
NUM_ITERS = 18000 so that inspect_tensor_enabled is never called again.
elif layer == "transformer":
model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda()
NUM_ITERS = 2000
x = torch.randn(1, 1, 1).cuda() Tests both with and without microbatching to ensure proper behavior in both scenarios.
"""
try:
debug_api.initialize(
config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml",
feature_dirs=feature_dirs,
)
import transformer_engine.debug.features._test_dummy_feature as dummy_feature
# Reset counters
dummy_feature._inspect_tensor_enabled_call_count = 0
dummy_feature._inspect_tensor_call_count = 0
model = te.Linear(256, 256, name="test_linear").cuda()
x = torch.randn(8, 256, 256).cuda()
# Run multiple iterations
for i in range(20):
if use_microbatching:
# Alternate between first and non-first microbatch
is_first_microbatch = i % 2 == 0
y = model(x, is_first_microbatch=is_first_microbatch)
else:
# Run without specifying is_first_microbatch
y = model(x) y = model(x)
y.sum().backward() y.sum().backward()
debug_api.step() debug_api.step()
torch.cuda.synchronize()
time_start = time.time() # Verify inspect_tensor_enabled was called only once per tensor
for i in range(NUM_ITERS): # (activation, weight, gradient, output, wgrad, dgrad)
y = model(x) enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count
y.sum().backward() microbatch_info = "with microbatching" if use_microbatching else "without microbatching"
if debug_tools_initialized: assert enabled_call_count == 6, (
debug_api.step() f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), "
torch.cuda.synchronize() "but should be called 6 times to check if feature is needed for each tensor "
time_end = time.time() "(activation, weight, gradient, output, wgrad, dgrad)"
)
# Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None)
inspect_call_count = dummy_feature._inspect_tensor_call_count
assert inspect_call_count == 0, (
f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), "
"but should never be called when inspect_tensor_enabled returns (False, None)"
)
finally: finally:
if debug_tools_initialized:
debug_api.end_debug() debug_api.end_debug()
TEDebugState._reset()
return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
# runs one layer many times on very small tensor
# - gpu time should be negligible, so time should be dominated by cpu time.
# if layers does not invoke any feature in current iteration,
# then it changed into non-debug mode and should not have any non-negligible cpu overhead
# compared to layer without debug tools initialized.
with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
print(f"with_debug_tools: {with_debug_tools} s")
print(f"without_debug_tools: {without_debug_tools} s")
assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin
...@@ -8,58 +8,74 @@ import os ...@@ -8,58 +8,74 @@ import os
import sys import sys
import argparse import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
)
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed.tensor import DTensor
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, optim from torch import nn, optim
from torch.distributed import DeviceMesh from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext from contextlib import nullcontext
import transformer_engine.pytorch as te LOCAL_RANK = None
from transformer_engine.common.recipe import Format, DelayedScaling
class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
def restore_custom_attrs(module, custom_attrs): def dist_print(msg):
for name, param in module.named_parameters(): if LOCAL_RANK == 0:
if name in custom_attrs: print(msg)
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
def _parse_args(argv=None, namespace=None): def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
parser.add_argument( parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
) )
parser.add_argument(
"--recipe",
type=str,
default="mx_fp8_block_scaling",
help="Quantizer type.",
choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
)
parser.add_argument(
"--layer-type",
type=str,
default="TransformerLayer",
choices=[
"Linear",
"LayerNormLinear",
"LayerNormMLP",
"MultiheadAttention",
"TransformerLayer",
],
help="Transformer Engine layer type",
)
parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
parser.add_argument( parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass" "--iter", type=int, default=10, help="Number of iterations for forward pass"
) )
parser.add_argument(
"--device",
type=str,
default="meta",
help="Device to run the model on.",
choices=["cuda", "meta"],
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated # Adding hsdp_dim as a list argument, comma-separated
parser.add_argument( parser.add_argument(
...@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None): ...@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
return args return args
sub_modules_to_wrap = [te.Linear] ## Methods to help initialize the TE model in an FSDP2 setting
## with required configurations based on command line args
def get_te_layer_from_string(layer_name):
te_layer_types = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
te_layer_names = [layer.__name__ for layer in te_layer_types]
te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
if layer_name.lower() not in te_layer_map.keys():
raise argparse.ArgumentTypeError(
f'"{layer_name}" is not a valid Transformer Engine layer, '
f"please choose layer from {te_layer_names}."
)
return te_layer_map[layer_name.lower()]
def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
if recipe == "delayed_scaling":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
elif recipe == "current_scaling":
return Float8CurrentScaling(fp8_format=fp8_format)
elif recipe == "mx_fp8_block_scaling":
return MXFP8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unknown quantizer type: {recipe}")
def init_te_model(config):
hidden_size = config.num_heads * config.head_dim
args = [hidden_size, hidden_size]
inp_shape = [config.seq_length, config.batch_size, hidden_size]
out_shape = [config.seq_length, config.batch_size, hidden_size]
if config.params_dtype == "float16":
params_dtype = torch.float16
elif config.params_dtype == "bfloat16":
params_dtype = torch.bfloat16
else:
params_dtype = torch.float32
kwargs = {
"params_dtype": params_dtype,
}
kwargs["device"] = config.device
layer_type = get_te_layer_from_string(config.layer_type)
# We are creating model in a way so that we can test both reshard_after_forward=True/False cases.
# more details below.
if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
# For this case, we are creating a model that resemebles production use-cases
# wherein there are mltiple TransformerLayers in the model. And we would need
# to shard each transformer layer. Since each transformer layer is not a root module,
# FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model.
args[1] *= 4 # FFN hidden size
args.append(config.num_heads)
kwargs["fuse_qkv_params"] = True
if layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
elif layer_type == te.LayerNormLinear:
# For this case, we are creating a model with just one LayerNormLinear layer
# so that the model itself is a root module, and FSDP2's fully_shard assigns
# reshard_after_forward=True for the parameters of these model.
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
model = layer_type(*args, **kwargs)
else:
model = layer_type(*args, **kwargs)
return model, inp_shape, out_shape
def get_device_mesh(world_size, sharding_dims):
dist_print(f"sharding-dims:{sharding_dims}")
device_ids = list(range(world_size))
if sharding_dims is None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 1:
assert sharding_dims[0] == world_size
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 2: # HSDP
assert sharding_dims[0] * sharding_dims[1] == world_size
mesh = init_device_mesh(
"cuda",
(sharding_dims[0], sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False
return mesh
def shard_model_with_fsdp2(model, mesh):
for child in model.children():
fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
return model
#### Methods to save the custom attributes of QuantizedTensors before sharding
#### them with FSDP2, and restore them after sharding.
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes. Otherwise we will save duplicate copies
# for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save.
ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
# Do manual allgather in fp32 and match against fp8 allgather done
# with fsdp2
# FP32 manual weight allgather
fp32_allgathered_params = {}
for name, param in model.named_parameters():
assert isinstance(param, DTensor)
local_tensor = param._local_tensor
device_mesh = param.device_mesh
dist_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
# Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch
# for local_tensor will go down the dequantization route.
gathered_tensor = [
torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
]
dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
full_tensor = torch.cat(gathered_tensor, dim=0)
fp32_allgathered_params[name] = full_tensor
# FP8 allgather using FSDP2
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "unshard"):
module.unshard()
# Make sure allgathered parameters match exactly
for name, param in model.named_parameters():
assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
# Revert model to original sharded state
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "reshard"):
module.reshard()
def _train(args): def _train(args):
global LOCAL_RANK
assert "TORCHELASTIC_RUN_ID" in os.environ assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
...@@ -103,77 +279,69 @@ def _train(args): ...@@ -103,77 +279,69 @@ def _train(args):
# FP8 Configuration # FP8 Configuration
fp8_format = Format.HYBRID fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)
build_model_context_args = {}
if not args.fp8_init: if not args.fp8_init:
# Build model context (FP8 init) # Build model context (FP8 init)
build_model_context = nullcontext build_model_context = nullcontext
build_model_context_args = {} else:
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch import quantized_model_init
build_model_context = quantized_model_init build_model_context = fp8_model_init
build_model_context_args["enabled"] = True build_model_context_args["enabled"] = True
build_model_context_args["recipe"] = fp8_recipe
# Build the model with the specified context dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
# Create the model on the meta/cuda device as per args
with build_model_context(**build_model_context_args): with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size) model, inp_shape, out_shape = init_te_model(args)
else: dist_print(
model = SimpleNet(args.input_size, args.hidden_size, args.output_size) f"Memory after model init on device {args.device}:"
# Move the model to the correct device f" {torch.cuda.memory_allocated(device)/1e6} MB"
)
model.to(device)
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard # Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE) world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP # Setup the sharding mesh for FSDP/HSDP
if args.sharding_dims == None: # FSDP mesh = get_device_mesh(world_size, args.sharding_dims)
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
assert args.sharding_dims[0] == device_ids[-1] + 1
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 2: # HSDP
assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
mesh = init_device_mesh(
"cuda",
(args.sharding_dims[0], args.sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False
# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model) custom_attrs = save_custom_attrs(model)
for sub_module in model.modules(): model = shard_model_with_fsdp2(model, mesh)
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs) restore_custom_attrs(model, custom_attrs)
# model now has DTensors as its parameters
if args.device == "meta":
# After FSDP2 has been applied, materialize and initialize the sharded parameters
# TE base.py's reset_parameters() handles DTensors with FP8 initialization
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
dist_print(f" Sharded parameters materialized and initialized on cuda device.")
dist_print(
f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
)
optimizer = optim.Adam(model.parameters(), lr=1e-3) optimizer = optim.Adam(model.parameters(), lr=1e-3)
for iteration in range(args.iter): for iteration in range(args.iter):
# Zero the parameter gradients # Zero the parameter gradients
optimizer.zero_grad() optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device) input_data = torch.randn(inp_shape).to(device)
with te.autocast(enabled=True, recipe=fp8_recipe):
output = model(input_data) output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device) target = torch.randn(out_shape).to(device)
loss = F.mse_loss(output, target) loss = F.mse_loss(output, target)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if LOCAL_RANK == 0: dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
# Some of the FSDP states are lazy initialized during FSDP forward pass
# so testing fp8 allgather at the end of the training loop.
if args.fp8_init:
test_fp8_fsdp2_allgather(model)
dist.destroy_process_group() dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0 return 0
......
...@@ -22,8 +22,8 @@ from transformer_engine.common.recipe import ( ...@@ -22,8 +22,8 @@ from transformer_engine.common.recipe import (
) )
from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4 from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors from run_layer_with_overlap import _compare_tensors
...@@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ...@@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
sequence_parallel (bool): Enable sequence parallelism if True. sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer. kwargs (dict): Additional arguments for the linear layer.
QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference
""" """
params_dtype = torch.bfloat16 params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True) use_bias = kwargs.get("bias", True)
......
...@@ -34,6 +34,7 @@ from transformer_engine.pytorch import ( ...@@ -34,6 +34,7 @@ from transformer_engine.pytorch import (
Float8Tensor, Float8Tensor,
) )
# Import utility functions # Import utility functions
_current_file = pathlib.Path(__file__).resolve() _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent)) sys.path.append(str(_current_file.parent.parent))
......
...@@ -14,7 +14,7 @@ import transformer_engine.pytorch as te ...@@ -14,7 +14,7 @@ import transformer_engine.pytorch as te
Distributed numerics tests Distributed numerics tests
This numerical test aims for zero tolerance test for absolute confidence in numerics. This numerical test aims for zero tolerance test for absolute confidence in numerics.
In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise In the case of NVFP4, with the custom NVFP4 quantization, we matched bitwise
result with the native silicon. For distrbuted test cases, we can do the same by thing result with the native silicon. For distrbuted test cases, we can do the same by thing
by comparing BF16 AG results with the low precision AG results at layer level. by comparing BF16 AG results with the low precision AG results at layer level.
""" """
......
...@@ -12,22 +12,26 @@ import transformer_engine.pytorch as te ...@@ -12,22 +12,26 @@ import transformer_engine.pytorch as te
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
NUM_PROCS: int = torch.cuda.device_count() NUM_PROCS: int = torch.cuda.device_count()
def _run_test(fp_init, sharding_dims): def _run_test(fp_init, sharding_dims, recipe, layer_type):
test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
if fp_init: if fp_init:
test_cmd += ["--fp8-init"] test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1: if len(sharding_dims) == 1:
test_cmd += ["--sharding-dims", str(sharding_dims[0])] test_cmd += ["--sharding-dims", str(sharding_dims[0])]
elif len(sharding_dims) == 2: elif len(sharding_dims) == 2:
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else: else:
assert False assert False
test_cmd += ["--recipe", recipe]
test_cmd += ["--layer-type", layer_type]
result = subprocess.run(test_cmd, env=os.environ, check=True) result = subprocess.run(test_cmd, env=os.environ, check=True)
...@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims): ...@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True)) @pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims): @pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
# Skip invalid configurations # Skip invalid configurations
if torch.cuda.device_count() < 4: if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs") pytest.skip("FSDP2 test requires at least 4 GPUs")
if fp8_init and not fp8_available: if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
elif not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
_run_test(fp8_init, sharding_dims) _run_test(fp8_init, sharding_dims, recipe, layer_type)
def test_dummy() -> None: def test_dummy() -> None:
......
...@@ -8,8 +8,8 @@ import transformer_engine.pytorch as te ...@@ -8,8 +8,8 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.custom_recipes import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
...@@ -6,8 +6,8 @@ import pytest ...@@ -6,8 +6,8 @@ import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4 from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.custom_recipes import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
...@@ -7,10 +7,10 @@ import torch ...@@ -7,10 +7,10 @@ import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.experimental import utils
import pytest import pytest
import torch import torch
......
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, Optional
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes. The test matrix always includes
# unquantized compute (None) and adds the FP8 recipes when the hardware
# supports them.
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
# Fix: the annotation previously claimed a single Optional[recipe.Recipe],
# but the value is a list of optional recipes.
quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
# Shared model dimensions for every test model built below.
model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size  # hidden dimension used by all layers
NUM_HEADS = model_config["small"].num_heads  # attention heads for attention models
NUM_LAYERS = model_config["small"].num_layers  # modules stacked per test model
EPSILON = model_config["small"].eps  # tolerance (MiB) for the memory comparison in test_cpu_offload

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"

# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"

# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
#
# Values are factories (not instances) so each test constructs fresh modules.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}
def _make_input() -> torch.Tensor:
    """Create a random bfloat16 CUDA input tensor for the test models."""
    shape = (128, SIZE, SIZE)
    return torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Run one forward and backward pass through the module stack.

    Quantized autocast is enabled only when a recipe is provided.
    """
    quantized = quantization_recipe is not None
    activations = _make_input()
    for layer in modules:
        with te.autocast(enabled=quantized, recipe=quantization_recipe):
            activations = layer(activations)
    activations.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.
    Memory measurement excludes the input and output tensors: the baseline is
    taken after the input is allocated, and the output's bytes are subtracted
    from the post-forward reading.
    """
    # Reset memory so state left over from previous runs does not skew the reading
    gc.collect()
    torch.cuda.empty_cache()
    # Context and sync function for CPU offloading; both are no-ops when
    # offloading is disabled
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            # presumably the +1 accounts for the dummy clone step below — verify
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x
    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None, recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
    # Subtract the output tensor's own bytes so only cached state is counted
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()
    # Memory usage in MiB
    return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""
    # Build a stack of freshly constructed modules for this model type.
    layers = [model_types[model_name]() for _ in range(NUM_LAYERS)]

    # Attention models need the fused attention backend; skip when it is
    # unavailable and force flash attention off otherwise.
    if model_name in ("multihead_attention", "transformer_layer"):
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup so lazy initialization does not distort the measurements.
    _warmup_model(layers, quantization_recipe)

    # Measure cached memory after the forward pass, with and without offload.
    baseline_memory = _measure_cached_memory(layers, quantization_recipe, False)
    offload_memory = _measure_cached_memory(layers, quantization_recipe, True)

    # Offloading must reduce cached memory, and what remains should be
    # explained by the cached (non-offloadable) quantized weights.
    assert offload_memory < baseline_memory
    expected_weight_cache = _estimate_cached_weight_size(
        model_name,
        layers,
        quantization_recipe,
    )
    assert abs(offload_memory - expected_weight_cache) < EPSILON
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from typing import Iterable, List, Union from typing import Callable, Dict, Iterable, List, Tuple, Union
import pytest import pytest
import torch import torch
...@@ -173,6 +173,20 @@ def get_outputs( ...@@ -173,6 +173,20 @@ def get_outputs(
return values return values
def reset_graphs(
graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
) -> None:
"""Reset CUDA graphs."""
if isinstance(graphed_callables, tuple) or isinstance(graphed_callables, list):
for callable in graphed_callables:
callable.reset()
elif isinstance(graphed_callables, dict):
for callable in graphed_callables.values():
callable.reset()
else:
graphed_callables.reset()
class _Sequential(torch.nn.Sequential): class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules""" """Sequential model that forwards keyword arguments to modules"""
...@@ -335,7 +349,12 @@ def _test_cuda_graphs( ...@@ -335,7 +349,12 @@ def _test_cuda_graphs(
output.backward(grad_output) output.backward(grad_output)
optimizer.step() optimizer.step()
return get_outputs(model, output) outputs = get_outputs(model, output)
if graph_mode == "full":
reset_graphs(model)
elif graph_mode == "individual":
reset_graphs(modules)
return outputs
@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
...@@ -487,7 +506,10 @@ def _test_cuda_graphs_with_dot_product_attention( ...@@ -487,7 +506,10 @@ def _test_cuda_graphs_with_dot_product_attention(
output = model(*inputs) output = model(*inputs)
output.backward(grad_output) output.backward(grad_output)
return get_outputs(model, output) outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dtype", dtypes)
...@@ -572,7 +594,10 @@ def _test_cuda_graphs_with_kwargs( ...@@ -572,7 +594,10 @@ def _test_cuda_graphs_with_kwargs(
output.backward(grad_output) output.backward(grad_output)
optimizer.step() optimizer.step()
return get_outputs(model, output) outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
def test_make_graphed_callables_with_kwargs( def test_make_graphed_callables_with_kwargs(
...@@ -687,7 +712,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( ...@@ -687,7 +712,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
optimizer.step() optimizer.step()
outputs = [y for _, y in sorted(outputs.items())] outputs = [y for _, y in sorted(outputs.items())]
return get_outputs(model, outputs) outputs = get_outputs(model, outputs)
if with_graph:
reset_graphs(layer_forwards)
return outputs
def test_make_graphed_callables_with_interleaved_pipeline_parallelism( def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment