Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
@@ -45,7 +45,8 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
@@ -60,8 +61,16 @@ from utils import (
    get_available_attention_backends,
)

# Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
    )
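For example, a device reporting compute capability (9, 0) passes this guard, while one reporting (12, 0) is excluded and its reason string formats as sm120 (illustrative values, not queried from hardware):

cc = (12, 0)
print(f"sm{cc[0] * 10 + cc[1]}")  # -> sm120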
# Reset RNG seed and states
seed = 1234
@@ -130,6 +139,11 @@ def test_dot_product_attention(
    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )

    # Get backends
    is_training = True
@@ -171,7 +185,7 @@ def test_dot_product_attention(
    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
@@ -185,7 +199,7 @@ def test_dot_product_attention(
    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
@@ -197,7 +211,7 @@ def test_dot_product_attention(
            )
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
@@ -208,7 +222,7 @@ def test_dot_product_attention(
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
@@ -221,7 +235,7 @@ def test_dot_product_attention(
    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
@@ -242,6 +256,8 @@
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
@@ -265,6 +281,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_max_logit = {
# test: ModelConfig(b, sq, hq, dqk)
"max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"max_logit_4": ModelConfig(
8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
),
"max_logit_5": ModelConfig(
8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
),
"max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_logit = True
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
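The "max logit" compared here is the maximum pre-softmax attention score. A minimal pure-PyTorch sketch of the quantity the backends are expected to agree on (tensor names, shapes, and the global reduction are illustrative assumptions, not the module's API):

import torch

def reference_max_logit(q, k, scale=None):
    # q: [sq, h, d], k: [skv, h, d]; logits are the scaled per-head QK^T scores
    if scale is None:
        scale = q.shape[-1] ** -0.5
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    return logits.amax()  # assumed: a single max over heads and positions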
model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -962,6 +1005,8 @@ def _run_dot_product_attention(
        layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens
        # tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1071,6 +1116,7 @@
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
@@ -1108,16 +1154,21 @@
            alibi_slopes=alibi_slopes,
            fast_zero_fill=True,
        )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
        else:
            return out, max_logit, (None, None, None, d_softmax_offset)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1146,14 +1197,18 @@
                [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
            )
            if is_training:
                return (
                    out_orig,
                    max_logit,
                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
                )
            else:
                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
                return out, max_logit, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
@@ -1527,8 +1582,7 @@ model_configs_fp8_extra_state = {
}


@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -1690,8 +1744,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@@ -1927,8 +1980,7 @@ def _run_mha_fp8_vs_f16(
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@@ -2256,8 +2308,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
...
@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
    # test: ModelConfig(b, sq, hq, dqk)
    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True),  # MHA
    "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True),  # MHA
    "cp_1_2": ModelConfig(
        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),  # MHA
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
    configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
    model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
    dtypes = ["bf16", "fp8"]
    qkv_formats = ["sbhd", "thd"]
...
@@ -685,11 +685,12 @@ if __name__ == "__main__":
        pass
    else:
        test_log_expert_parallel()
    for parallel_mode in ["column", "row"]:
        for gather_weight in [True, False]:
            test_log_distributed(parallel_mode, gather_weight)
    if fp8_available:
        for parallel_mode in ["row", "column"]:
            test_disable_fp8_layer(parallel_mode)
...
test_switch_to_nondebug_mode:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
TestDummyFeature:
enabled: True
inspect_only_once: True
tensors: [weight, activation, gradient, output, wgrad, dgrad]
gemms: [wgrad, dgrad, fprop]
@@ -18,7 +18,11 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_computation import (
    compute_max_blockwise_dynamic_range,
    BlockwiseDynamicRangeStat,
)
import math

fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@@ -154,7 +158,7 @@ fp8_recipes = [
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs):
    assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
LOG_HIGH_PRECISION_CONFIG = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
stats:
- dynamic_range
- max_blockwise_dynamic_range:
block_size: 4
dims: 1
- max_blockwise_dynamic_range:
block_size: 4
dims: 2
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
def test_log_stats_numerics(feature_dirs, tensor_name):
"""Check correctness of dynamic range and max blockwise dynamic range stats.
Tests different tensor types:
- activation/weight: use both orientations (rowwise + columnwise), takes max
- gradient/dgrad: use single orientation (rowwise only)
"""
log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
        # A 1024 x 1024 tensor with a very small epsilon value in almost all elements,
        # one row of large value A, and three rows of large value B.
epsilon = 1e-10
A = 1000
B = 50
tensor = torch.zeros(1024, 1024).cuda() + epsilon
tensor[0, :] = A
tensor[1:4, :] = B
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name=tensor_name,
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=None,
rowwise_quantized_tensor=None,
columnwise_quantized_tensor=None,
)
debug_api.step()
output = read_log(log_dir)
max_over_orientations = tensor_name in ["activation", "weight"]
max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""
# Track which stats were found to ensure all are present
found_dims_1 = False
found_dims_2 = False
found_dynamic_range = False
for line in output.splitlines():
if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
if max_over_orientations:
# Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
expected = math.log2(A) - math.log2(B)
else:
# Rowwise blocks have uniform values -> dynamic_range = 0
expected = 0
assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
expected, abs=1e-4
)
found_dims_1 = True
elif (
f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
):
max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
# For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
expected = math.log2(A) - math.log2(B)
assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
expected, abs=1e-4
)
found_dims_2 = True
elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
dynamic_range = float(line.split("value=")[1])
expected = math.log2(A) - math.log2(epsilon)
assert dynamic_range == pytest.approx(expected, abs=1e-4)
found_dynamic_range = True
# Ensure all expected stats were found in the output
assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
assert found_dynamic_range, "dynamic_range not found in output"
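For reference, a minimal sketch of the row-wise (dims=1) statistic these assertions encode, assuming strictly positive inputs and a row length divisible by the block size (the helper name is illustrative, not the library's implementation):

import torch

def rowwise_max_blockwise_dynamic_range(t, block_size=4):
    # Split each row into contiguous blocks and take the largest
    # log2(block_max / block_min) over all blocks.
    blocks = t.reshape(-1, block_size)
    per_block = torch.log2(blocks.amax(dim=1)) - torch.log2(blocks.amin(dim=1))
    return per_block.amax()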
@pytest.mark.parametrize("layer", ["linear", "transformer"]) @pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
if not fp8_available: if not fp8_available:
...@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): ...@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
debug_api.end_debug() debug_api.end_debug()
TEDebugState._reset() TEDebugState._reset()
def test_compute_max_blockwise_dynamic_range_direct():
"""Direct unit test for compute_max_blockwise_dynamic_range function.
Tests the function with various configurations to ensure correct behavior
for different block sizes, dimensions, and orientation settings.
"""
# Create test tensor with uniform rows but mixed columns
# Row 0: all 1000, Row 1-3: all 50, remaining: all 0.01
epsilon = 0.01
A = 1000.0
B = 50.0
tensor = torch.zeros(1024, 1024).cuda() + epsilon
tensor[0, :] = A
tensor[1:4, :] = B
# Test 1: dims=1, max_over_orientations=False (rowwise only)
# Rowwise blocks have uniform values -> dynamic_range should be 0
stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
result = compute_max_blockwise_dynamic_range(tensor, stat_config)
assert result.item() == pytest.approx(
0.0, abs=1e-4
), "Rowwise 1D blocks with uniform values should have dynamic_range=0"
# Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise)
# Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
result = compute_max_blockwise_dynamic_range(tensor, stat_config)
expected = math.log2(A) - math.log2(B)
assert result.item() == pytest.approx(expected, abs=1e-4), (
f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got"
f" {result.item()}"
)
# Test 3: dims=2, block_size=4 (4x4 tiles)
# 2D blocks span multiple rows -> always have mixed values
stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
result = compute_max_blockwise_dynamic_range(tensor, stat_config)
expected = math.log2(A) - math.log2(B)
assert result.item() == pytest.approx(expected, abs=1e-4), (
f"2D blocks should capture mixed values from different rows, expected {expected}, got"
f" {result.item()}"
)
# Test 4: Different block size
# With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon]
# So max=A, min=epsilon (not B anymore)
stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
result = compute_max_blockwise_dynamic_range(tensor, stat_config)
expected = math.log2(A) - math.log2(epsilon) # min is epsilon, not B
assert result.item() == pytest.approx(
expected, abs=1e-4
), f"Block size 8 should work correctly, expected {expected}, got {result.item()}"
# Test 5: Tensor with all uniform values -> dynamic_range should be 0
uniform_tensor = torch.ones(64, 64).cuda() * 42.0
stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config)
assert result.item() == pytest.approx(
0.0, abs=1e-4
), "Uniform tensor should have dynamic_range=0"
# Test 6: 3D tensor flattening validation using 2D/3D comparison
# Create a 4x4 tensor with distinct 2x2 blocks, compute with dims=2, block_size=2
# Then reshape to 3D and compute again - results should match if flattening is correct
tensor_2d = torch.tensor(
[
[1.0, 1.0, 10.0, 10.0],
[1.0, 1.0, 10.0, 10.0],
[100.0, 100.0, 1000.0, 1000.0],
[100.0, 100.0, 1000.0, 1000.0],
]
).cuda()
# Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100)
stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config)
# Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct
tensor_3d = tensor_2d.reshape(2, 2, 4)
result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config)
assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), (
"3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got"
f" 2D={result_2d.item()}, 3D={result_3d.item()}"
)
print("All direct tests for compute_max_blockwise_dynamic_range passed!")
@@ -6,71 +6,70 @@
import pytest
import torch
import transformer_engine.pytorch as te

import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState


@pytest.mark.parametrize("use_microbatching", [False, True])
def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching):
    """
    Test that layers switch to non-debug mode when no features are active.

    Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled
    return (False, None). TE should:
    1. Call inspect_tensor_enabled to check if the feature is needed.
    2. Never call inspect_tensor.
    3. Allow layers to switch to non-debug mode for optimal performance,
       so that inspect_tensor_enabled is never called again.

    Tests both with and without microbatching to ensure proper behavior in both scenarios.
    """
    try:
        debug_api.initialize(
            config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml",
            feature_dirs=feature_dirs,
        )
        import transformer_engine.debug.features._test_dummy_feature as dummy_feature

        # Reset counters
        dummy_feature._inspect_tensor_enabled_call_count = 0
        dummy_feature._inspect_tensor_call_count = 0

        model = te.Linear(256, 256, name="test_linear").cuda()
        x = torch.randn(8, 256, 256).cuda()

        # Run multiple iterations
        for i in range(20):
            if use_microbatching:
                # Alternate between first and non-first microbatch
                is_first_microbatch = i % 2 == 0
                y = model(x, is_first_microbatch=is_first_microbatch)
            else:
                # Run without specifying is_first_microbatch
                y = model(x)
            y.sum().backward()
            debug_api.step()

        # Verify inspect_tensor_enabled was called only once per tensor
        # (activation, weight, gradient, output, wgrad, dgrad)
        enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count
        microbatch_info = "with microbatching" if use_microbatching else "without microbatching"
        assert enabled_call_count == 6, (
            f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), "
            "but should be called 6 times to check if the feature is needed for each tensor "
            "(activation, weight, gradient, output, wgrad, dgrad)"
        )

        # Verify inspect_tensor was never called - it should not be called when
        # inspect_tensor_enabled returns (False, None)
        inspect_call_count = dummy_feature._inspect_tensor_call_count
        assert inspect_call_count == 0, (
            f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), "
            "but should never be called when inspect_tensor_enabled returns (False, None)"
        )
    finally:
        debug_api.end_debug()
        TEDebugState._reset()
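For context, a hypothetical sketch of the counter bookkeeping the test reads from _test_dummy_feature; the real module is registered through TE's debug-feature machinery, which is omitted here:

# Illustrative only - module-level counters matching the names used above.
_inspect_tensor_enabled_call_count = 0
_inspect_tensor_call_count = 0


class TestDummyFeature:
    def inspect_tensor_enabled(self, *args, **kwargs):
        global _inspect_tensor_enabled_call_count
        _inspect_tensor_enabled_call_count += 1
        # inspect_only_once=True: report once that no inspection is needed.
        return False, None

    def inspect_tensor(self, *args, **kwargs):
        global _inspect_tensor_call_count
        _inspect_tensor_call_count += 1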
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
import torch
from torch import nn
import torch.distributed as dist
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
QuantizedTensor,
Float8Tensor,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import replace_raw_data
def _get_raw_data(quantized_tensor):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if isinstance(quantized_tensor, Float8Tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
class MiniZero_1:
"""A mini zero-1 optimizer implementation, just used for this test"""
def __init__(self, weights, lr, dp_group):
self.rank = dist.get_rank(dp_group)
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
# [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
self.offsets = [0]
for weight in self.weights:
self.offsets.append(self.offsets[-1] + weight.numel())
        # Pad so that the global buffer divides evenly by the world size; as a result,
        # offsets[-1] may extend past the end range of the last weight.
if self.offsets[-1] % self.world_size != 0:
self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size
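        # Worked example (illustrative): world_size=4 and weight sizes [5, 7] give
        # offsets=[0, 5, 12]; 12 % 4 == 0, so no padding. Sizes [5, 6] give
        # offsets=[0, 5, 11]; 11 % 4 == 3, so offsets[-1] is padded up to 12.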
self.master_weights = []
# The start offset of the master weight in the weight
self.start_offsets = []
# The overlapping area of the weight and this rank's local buffer
self.overlapping_areas = []
# The start and end of this rank's local buffer in the global buffer
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
for weight, offset in zip(self.weights, self.offsets[:-1]):
if offset >= rank_end or (offset + weight.numel()) <= rank_start:
# This weight is not in this rank's local buffer
master_weight = None
start_offset = None
overlapping_area = None
else:
overlapping_start = max(rank_start, offset)
overlapping_end = min(rank_end, offset + weight.numel())
length = overlapping_end - overlapping_start
start_offset = overlapping_start - offset
if isinstance(weight, QuantizedTensor):
# If weight is a FP8 tensor, we need to use the original high precision version
# to initialize the master weight.
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight = high_precision_init_val.to(weight.device).float()[
start_offset : start_offset + length
]
else:
master_weight = (
weight.detach().view(-1).float()[start_offset : start_offset + length]
)
overlapping_area = (overlapping_start, overlapping_end)
self.master_weights.append(master_weight)
self.start_offsets.append(start_offset)
self.overlapping_areas.append(overlapping_area)
# Create global buffer for grads reduce-scatter
self.grad_buffer = torch.empty(
[self.offsets[-1]], dtype=torch.float32, device=weights[0].device
)
self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]
# Create global buffer for weights all-gather
if isinstance(self.weights[0], QuantizedTensor):
weight_buffer_dtype = torch.uint8
else:
weight_buffer_dtype = weights[0].dtype
self.weight_buffer = torch.empty(
[self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
)
self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]
def step(self):
# -----------------------------------------------------------------------------------------
# Step 1: Copy grads to the grad buffer
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))
# -----------------------------------------------------------------------------------------
# Step 2: Grads reduce-scatter
# -----------------------------------------------------------------------------------------
# Don't use reduce_scatter directly to explicitly control the reduce order.
# dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
# group=self.dp_group)
buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
self.grad_buffer_slice /= self.world_size
# -----------------------------------------------------------------------------------------
# Step 3: Update master weights
# -----------------------------------------------------------------------------------------
for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
if master_weight is None:
# This weight's master weight is in other rank.
continue
grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
master_weight -= grad * self.lr
# -----------------------------------------------------------------------------------------
# Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
# -----------------------------------------------------------------------------------------
if isinstance(self.weights[0], QuantizedTensor):
# FP8 weights case
for i in range(1, len(self.weights)):
assert isinstance(self.weights[i], QuantizedTensor)
cast_master_weights_to_fp8(
self.weights, self.master_weights, self.start_offsets, self.dp_group
)
else:
# BF16 weights case
for weight, master_weight, start_offset in zip(
self.weights, self.master_weights, self.start_offsets
):
if master_weight is None:
continue
start = start_offset
end = start_offset + master_weight.numel()
weight.data.view(-1)[start:end].copy_(master_weight)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i])
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
class MiniOptimizer:
def __init__(self, weights, lr, dp_group):
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
master_weights = []
for weight in self.weights:
master_weights.append(weight.detach().float())
self.master_weights = master_weights
def step(self):
for weight, master_weight in zip(self.weights, self.master_weights):
main_grad = weight.main_grad
# Don't use all-reduce directly to explicitly control the reduce order.
# dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
dist.all_gather(buffers, main_grad, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
main_grad.copy_(buffers[0])
main_grad /= self.world_size
master_weight -= main_grad * self.lr
weight.data.copy_(master_weight)
class MiniFSDP:
def __init__(self, weights, lr, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
# Flatten the weights and pad to align with world size
        raw_data_list = [
            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
            for w in weights
        ]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
shard_size = self.flatten_weight.size(0) // world_size
# Map original tensors to flattened indices
tensor_indices = []
cumulative_length = 0
for tensor in raw_data_list:
length = tensor.size(0)
tensor_indices.append((cumulative_length, cumulative_length + length))
cumulative_length += length
# Build shard index mappings
self.weight_indices = []
self.shard_indices = []
for idx, (start, end) in enumerate(tensor_indices):
shard_start = rank * shard_size
shard_end = shard_start + shard_size
adjusted_end = min(shard_end, original_length)
if start <= adjusted_end and end >= shard_start:
start_idx = max(start, shard_start)
end_idx = min(end, adjusted_end)
self.weight_indices.append((start_idx - start, end_idx - start))
self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
else:
self.weight_indices.append((None, None))
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
shard_start, shard_end = self.shard_indices[i]
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight_shard = high_precision_init_val.to(weight.device).float()[
weight_start:weight_end
]
else:
master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
)
def _flatten_tensors_with_pad(self, tensors):
"""
Flatten the list of tensors and pad them to align with the world size.
Args:
tensors (list): List of tensors to flatten.
Returns:
tuple: Flattened tensor and its original length before padding.
"""
world_size = dist.get_world_size(self.dp_group)
flatten_tensor = torch.cat(tensors)
original_length = flatten_tensor.size(0)
padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            flatten_tensor = torch.cat(
                [
                    flatten_tensor,
                    torch.zeros(
                        padding_needed, dtype=flatten_tensor.dtype, device=flatten_tensor.device
                    ),
                ]
            )
return flatten_tensor, original_length
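    # Example (illustrative): with world_size=2 and tensor lengths [3, 4], the
    # concatenated length is 7, padding_needed = (2 - 7 % 2) % 2 = 1, so the
    # returned buffer has length 8 and original_length == 7.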
def zero_grad(self):
for weight in self.weights:
weight.grad = None
weight.main_grad.zero_()
def step(self):
"""
Perform an optimization step for the distributed sharded model.
This method includes:
1. Gradient reduce-scatter: Synchronize gradients across all processes.
2. Master weight update: Update high-precision master weights using local gradients.
3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
4. Weight synchronization: All-gather updated weights across all processes.
Returns:
None
"""
# Step 1: Reduce-scatter the gradients
main_grad_buffer, _ = self._flatten_tensors_with_pad(
[weight.main_grad.view(-1) for weight in self.weights]
)
main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
dist.reduce_scatter_tensor(
self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
)
# Step 2: Update the master weights
for weight, master_weight, (shard_start, shard_end) in zip(
self.weights, self.master_weights, self.shard_indices
):
if master_weight is None:
continue
# Extract the local gradient shard for this weight
grad = self.local_main_grad_shard[shard_start:shard_end]
# Update the master weight using gradient descent
master_weight -= grad * self.lr
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
self.master_weights,
[idx[0] for idx in self.weight_indices],
self.dp_group,
local_weights,
)
else:
for weight, master_weight in zip(self.local_weights, self.master_weights):
if master_weight is None:
continue
# Copy updated master weights to local weights
weight.data.copy_(master_weight)
# Step 4: All-gather updated weights across processes
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
# Configuration constants
NUM_STEPS = 100
SEED = 12345
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {
"params_dtype": torch.bfloat16,
"bias": False,
"fuse_wgrad_accumulation": False,
}
# Create model with FP8 weights
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
    for _ in range(NUM_STEPS):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.autocast(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.autocast(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
print(
f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
f" {quantization} quantization"
)
def _test_zero_1(dp_group):
"""Make sure the implementation of zero-1 optimizer is correct"""
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
weights = [
torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
]
weights_1 = weights
weights_2 = [weight.clone() for weight in weights]
lr = 1.0
optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
for _ in range(100):
for w1, w2 in zip(weights_1, weights_2):
main_grads = [
torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the grads of different ranks are different.
main_grad = main_grads[rank]
w1.main_grad = main_grad
w2.main_grad = main_grad
optimizer_1.step()
optimizer_2.step()
for w1, w2 in zip(weights_1, weights_2):
torch.testing.assert_close(w1, w2, atol=0, rtol=0)
def quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup"""
fp8_format = Format.HYBRID
if quantization == "fp8":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
elif quantization == "fp8_cs":
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unsupported quantization: {quantization}")
def _test_cast_master_weights_to_fp8(quantization, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}
# Create model with FP8 weights
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
# Allocate main_grads for each weight
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")
optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for i in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_()
w.main_grad.zero_()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.autocast(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.autocast(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def main(argv=None, namespace=None):
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
parser = argparse.ArgumentParser()
parser.add_argument(
"--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
)
args = parser.parse_args(argv, namespace)
dp_group = dist.new_group(backend="nccl")
_test_zero_1(dp_group)
_test_cast_master_weights_to_fp8(args.quantization, dp_group)
_test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
dist.destroy_process_group()
return 0
if __name__ == "__main__":
sys.exit(main())
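The script reads torchrun-style environment variables (RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE), so an assumed single-node launch looks like the following (file name is a placeholder):

# torchrun --nproc_per_node=2 <this_test_file>.py --quantization fp8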
...@@ -8,58 +8,74 @@ import os ...@@ -8,58 +8,74 @@ import os
import sys import sys
import argparse import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
)
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed.tensor import DTensor
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, optim from torch import nn, optim
from torch.distributed import DeviceMesh from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext from contextlib import nullcontext
import transformer_engine.pytorch as te LOCAL_RANK = None
from transformer_engine.common.recipe import Format, DelayedScaling
class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
def restore_custom_attrs(module, custom_attrs): def dist_print(msg):
for name, param in module.named_parameters(): if LOCAL_RANK == 0:
if name in custom_attrs: print(msg)
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
    parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
    parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
    parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--recipe",
        type=str,
        default="mx_fp8_block_scaling",
        help="Quantizer type.",
        choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
    )
    parser.add_argument(
        "--layer-type",
        type=str,
        default="TransformerLayer",
        choices=[
            "Linear",
            "LayerNormLinear",
            "LayerNormMLP",
            "MultiheadAttention",
            "TransformerLayer",
        ],
        help="Transformer Engine layer type",
    )
    parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="meta",
        help="Device to run the model on.",
        choices=["cuda", "meta"],
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # Adding hsdp_dim as a list argument, comma-separated
    parser.add_argument(
...@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
    return args
## Methods to help initialize the TE model in an FSDP2 setting
## with required configurations based on command line args
def get_te_layer_from_string(layer_name):
te_layer_types = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
te_layer_names = [layer.__name__ for layer in te_layer_types]
te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
if layer_name.lower() not in te_layer_map.keys():
raise argparse.ArgumentTypeError(
f'"{layer_name}" is not a valid Transformer Engine layer, '
f"please choose layer from {te_layer_names}."
)
return te_layer_map[layer_name.lower()]
def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
if recipe == "delayed_scaling":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
elif recipe == "current_scaling":
return Float8CurrentScaling(fp8_format=fp8_format)
elif recipe == "mx_fp8_block_scaling":
return MXFP8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unknown quantizer type: {recipe}")
def init_te_model(config):
hidden_size = config.num_heads * config.head_dim
args = [hidden_size, hidden_size]
inp_shape = [config.seq_length, config.batch_size, hidden_size]
out_shape = [config.seq_length, config.batch_size, hidden_size]
if config.params_dtype == "float16":
params_dtype = torch.float16
elif config.params_dtype == "bfloat16":
params_dtype = torch.bfloat16
else:
params_dtype = torch.float32
kwargs = {
"params_dtype": params_dtype,
}
kwargs["device"] = config.device
layer_type = get_te_layer_from_string(config.layer_type)
    # We create the model in a way that lets us test both the reshard_after_forward=True
    # and reshard_after_forward=False cases; more details below.
    if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
        # For this case, we create a model that resembles production use-cases,
        # wherein there are multiple TransformerLayers in the model, and we need
        # to shard each transformer layer. Since each transformer layer is not a root module,
        # FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model.
args[1] *= 4 # FFN hidden size
args.append(config.num_heads)
kwargs["fuse_qkv_params"] = True
if layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
elif layer_type == te.LayerNormLinear:
# For this case, we are creating a model with just one LayerNormLinear layer
# so that the model itself is a root module, and FSDP2's fully_shard assigns
        # reshard_after_forward=True for the parameters of this model.
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
model = layer_type(*args, **kwargs)
else:
model = layer_type(*args, **kwargs)
return model, inp_shape, out_shape
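# Sketch of the shape bookkeeping above with the default CLI values (illustrative
# numbers only):
#
#     num_heads, head_dim = 8, 64
#     batch_size, seq_length = 16, 128
#     hidden_size = num_heads * head_dim                   # 8 * 64 = 512
#     inp_shape = [seq_length, batch_size, hidden_size]    # [128, 16, 512]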
def get_device_mesh(world_size, sharding_dims):
dist_print(f"sharding-dims:{sharding_dims}")
device_ids = list(range(world_size))
if sharding_dims is None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 1:
assert sharding_dims[0] == world_size
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 2: # HSDP
assert sharding_dims[0] * sharding_dims[1] == world_size
mesh = init_device_mesh(
"cuda",
(sharding_dims[0], sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False
return mesh
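# Hedged usage sketch (requires a torchrun job with an initialized process group;
# the world sizes below are examples):
#
#     mesh = get_device_mesh(4, [2, 2])   # HSDP: ("replicate", "shard") = (2, 2)
#     mesh = get_device_mesh(4, [4])      # flat 4-way FSDP
#     mesh = get_device_mesh(4, None)     # also flat FSDP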
def shard_model_with_fsdp2(model, mesh):
for child in model.children():
fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
return model
#### Methods to save the custom attributes of QuantizedTensors before sharding
#### them with FSDP2, and restore them after sharding.
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes. Otherwise we will save duplicate copies
# for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save.
ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
# Do manual allgather in fp32 and match against fp8 allgather done
# with fsdp2
# FP32 manual weight allgather
fp32_allgathered_params = {}
for name, param in model.named_parameters():
assert isinstance(param, DTensor)
local_tensor = param._local_tensor
device_mesh = param.device_mesh
dist_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
        # Perform manual all-gather on local_tensor. zeros_like will create a
        # high-precision tensor, since torch_dispatch for local_tensor goes down
        # the dequantization route.
gathered_tensor = [
torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
]
dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
full_tensor = torch.cat(gathered_tensor, dim=0)
fp32_allgathered_params[name] = full_tensor
# FP8 allgather using FSDP2
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "unshard"):
module.unshard()
# Make sure allgathered parameters match exactly
for name, param in model.named_parameters():
assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
# Revert model to original sharded state
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "reshard"):
module.reshard()
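# FSDP2 shards parameters along dim 0, so the manual all-gather above reconstructs
# the full weight by concatenating the per-rank shards in rank order. Toy
# illustration (plain tensors, no process group):
#
#     shards = [torch.ones(2, 4) * r for r in range(2)]   # rank-local pieces
#     full = torch.cat(shards, dim=0)                     # (4, 4) unsharded weight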
def _train(args):
    global LOCAL_RANK
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
...@@ -103,77 +279,69 @@ def _train(args):
    # FP8 Configuration
    fp8_format = Format.HYBRID
    fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

    build_model_context_args = {}
    if not args.fp8_init:
        # Build model context (FP8 init)
        build_model_context = nullcontext
    else:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
        build_model_context_args["recipe"] = fp8_recipe

    dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
    # Create the model on the meta/cuda device as per args
    with build_model_context(**build_model_context_args):
        model, inp_shape, out_shape = init_te_model(args)
    dist_print(
        f"Memory after model init on device {args.device}:"
        f" {torch.cuda.memory_allocated(device)/1e6} MB"
    )
    # Creating a DeviceMesh for fully_shard
    world_size = int(WORLD_SIZE)

    # Setup the sharding mesh for FSDP/HSDP
    mesh = get_device_mesh(world_size, args.sharding_dims)

    custom_attrs = save_custom_attrs(model)
    model = shard_model_with_fsdp2(model, mesh)
    restore_custom_attrs(model, custom_attrs)
# model now has DTensors as its parameters
if args.device == "meta":
# After FSDP2 has been applied, materialize and initialize the sharded parameters
# TE base.py's reset_parameters() handles DTensors with FP8 initialization
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
dist_print(f" Sharded parameters materialized and initialized on cuda device.")
dist_print(
f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(inp_shape).to(device)
        with te.autocast(enabled=True, recipe=fp8_recipe):
            output = model(input_data)
        target = torch.randn(out_shape).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
# Some of the FSDP states are lazy initialized during FSDP forward pass
# so testing fp8 allgather at the end of the training loop.
if args.fp8_init:
test_fp8_fsdp2_allgather(model)
    dist.destroy_process_group()
    return 0
...
...@@ -22,8 +22,8 @@ from transformer_engine.common.recipe import (
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors
...@@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.

        QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
...
...@@ -2,39 +2,746 @@
#
# See LICENSE for license information.

import argparse
import datetime
import os
import subprocess
import sys
import pathlib

import pytest
import torch

# NVTE_DISABLE_NVRTC=1 NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block

from torch import nn
import torch.distributed as dist

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
    Format,
    Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
is_fp8_available,
is_fp8_block_scaling_available,
QuantizedTensor,
Float8Tensor,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data
def _get_quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup"""
fp8_format = Format.HYBRID
if quantization == "fp8":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
elif quantization == "fp8_cs":
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unsupported quantization: {quantization}")
def _get_raw_data(quantized_tensor):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if isinstance(quantized_tensor, Float8Tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
class MiniOptimizer:
def __init__(self, weights, lr, dp_group):
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
master_weights = []
for weight in self.weights:
master_weights.append(weight.detach().float())
self.master_weights = master_weights
def step(self):
for weight, master_weight in zip(self.weights, self.master_weights):
main_grad = weight.main_grad
# Don't use all-reduce directly to explicitly control the reduce order.
# dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
dist.all_gather(buffers, main_grad, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
main_grad.copy_(buffers[0])
main_grad /= self.world_size
master_weight -= main_grad * self.lr
weight.data.copy_(master_weight)
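# The all-gather-then-sum pattern above fixes the reduction order across ranks, so
# the result is bit-identical between the optimizer implementations compared in this
# file. Minimal sketch of the idea (plain tensors, no process group):
#
#     def fixed_order_mean(shards):
#         acc = shards[0].clone()
#         for s in shards[1:]:
#             acc += s          # always sum in rank order 0, 1, ..., N-1
#         return acc / len(shards)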
class MiniZero_1:
"""A mini zero-1 optimizer implementation, just used for this test"""
def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
self.rank = dist.get_rank(dp_group)
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
self.manual_post_all_gather_processing = manual_post_all_gather_processing
# [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
self.offsets = [0]
for weight in self.weights:
self.offsets.append(self.offsets[-1] + weight.numel())
        # Pad so that the global buffer divides evenly by the world size; as a result,
        # offsets[-1] may extend past the end of the last weight.
if self.offsets[-1] % self.world_size != 0:
self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size
self.master_weights = []
# The start offset of the master weight in the weight
self.start_offsets = []
# The overlapping area of the weight and this rank's local buffer
self.overlapping_areas = []
# The start and end of this rank's local buffer in the global buffer
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
for weight, offset in zip(self.weights, self.offsets[:-1]):
if offset >= rank_end or (offset + weight.numel()) <= rank_start:
# This weight is not in this rank's local buffer
master_weight = None
start_offset = None
overlapping_area = None
else:
overlapping_start = max(rank_start, offset)
overlapping_end = min(rank_end, offset + weight.numel())
length = overlapping_end - overlapping_start
start_offset = overlapping_start - offset
if isinstance(weight, QuantizedTensor):
                    # If the weight is an FP8 tensor, we need to use the original
                    # high-precision version to initialize the master weight.
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight = high_precision_init_val.to(weight.device).float()[
start_offset : start_offset + length
]
else:
master_weight = (
weight.detach().view(-1).float()[start_offset : start_offset + length]
)
overlapping_area = (overlapping_start, overlapping_end)
self.master_weights.append(master_weight)
self.start_offsets.append(start_offset)
self.overlapping_areas.append(overlapping_area)
# Create global buffer for grads reduce-scatter
self.grad_buffer = torch.empty(
[self.offsets[-1]], dtype=torch.float32, device=weights[0].device
)
self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]
# Create global buffer for weights all-gather
if isinstance(self.weights[0], QuantizedTensor):
weight_buffer_dtype = torch.uint8
else:
weight_buffer_dtype = weights[0].dtype
self.weight_buffer = torch.empty(
[self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
)
self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]
def step(self):
# -----------------------------------------------------------------------------------------
# Step 1: Copy grads to the grad buffer
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))
# -----------------------------------------------------------------------------------------
# Step 2: Grads reduce-scatter
# -----------------------------------------------------------------------------------------
# Don't use reduce_scatter directly to explicitly control the reduce order.
# dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
# group=self.dp_group)
buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
self.grad_buffer_slice /= self.world_size
# -----------------------------------------------------------------------------------------
# Step 3: Update master weights
# -----------------------------------------------------------------------------------------
for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
if master_weight is None:
                # This weight's master weight lives on another rank.
continue
grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
master_weight -= grad * self.lr
# -----------------------------------------------------------------------------------------
# Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
# -----------------------------------------------------------------------------------------
if isinstance(self.weights[0], QuantizedTensor):
# FP8 weights case
for i in range(1, len(self.weights)):
assert isinstance(self.weights[i], QuantizedTensor)
cast_master_weights_to_fp8(
self.weights,
self.master_weights,
self.start_offsets,
self.dp_group,
manual_post_all_gather_processing=self.manual_post_all_gather_processing,
)
else:
# BF16 weights case
for weight, master_weight, start_offset in zip(
self.weights, self.master_weights, self.start_offsets
):
if master_weight is None:
continue
start = start_offset
end = start_offset + master_weight.numel()
weight.data.view(-1)[start:end].copy_(master_weight)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i])
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
if self.manual_post_all_gather_processing:
quantized_weights = [
weight for weight in self.weights if isinstance(weight, QuantizedTensor)
]
post_all_gather_processing(quantized_weights)
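# Toy sketch of the zero-1 partitioning arithmetic used above (sizes are made up):
# with world_size = 2 and weight sizes [5, 2], offsets become [0, 5, 7] and are
# padded to 8, so each rank owns 4 elements of the global buffer.
#
#     world_size, sizes = 2, [5, 2]
#     offsets = [0]
#     for n in sizes:
#         offsets.append(offsets[-1] + n)                       # [0, 5, 7]
#     if offsets[-1] % world_size:
#         offsets[-1] += world_size - offsets[-1] % world_size  # -> 8
#     # Rank 0 owns [0, 4): overlaps weight 0 at [0, 4).
#     # Rank 1 owns [4, 8): overlaps weight 0 at [4, 5) and weight 1 at [0, 2);
#     # the final padded element [7, 8) belongs to no weight.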
class MiniFSDP:
def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
self.manual_post_all_gather_processing = manual_post_all_gather_processing
# Flatten the weights and pad to align with world size
if isinstance(weights[0], QuantizedTensor):
raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
else:
raw_data_list = [w.view(-1) for w in weights]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(
self.local_weight_shard, dtype=torch.float32, device="cuda"
)
shard_size = self.flatten_weight.size(0) // world_size
# Map original tensors to flattened indices
tensor_indices = []
cumulative_length = 0
for tensor in raw_data_list:
length = tensor.size(0)
tensor_indices.append((cumulative_length, cumulative_length + length))
cumulative_length += length
# Build shard index mappings
self.weight_indices = []
self.shard_indices = []
for idx, (start, end) in enumerate(tensor_indices):
shard_start = rank * shard_size
shard_end = shard_start + shard_size
adjusted_end = min(shard_end, original_length)
if start <= adjusted_end and end >= shard_start:
start_idx = max(start, shard_start)
end_idx = min(end, adjusted_end)
self.weight_indices.append((start_idx - start, end_idx - start))
self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
else:
self.weight_indices.append((None, None))
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
shard_start, shard_end = self.shard_indices[i]
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight_shard = high_precision_init_val.to(weight.device).float()[
weight_start:weight_end
]
else:
master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
)
def _flatten_tensors_with_pad(self, tensors):
"""
Flatten the list of tensors and pad them to align with the world size.
Args:
tensors (list): List of tensors to flatten.
Returns:
tuple: Flattened tensor and its original length before padding.
"""
world_size = dist.get_world_size(self.dp_group)
flatten_tensor = torch.cat(tensors)
original_length = flatten_tensor.size(0)
padding_needed = (world_size - original_length % world_size) % world_size
if padding_needed > 0:
zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda")
flatten_tensor = torch.cat([flatten_tensor, zeros])
return flatten_tensor, original_length
def zero_grad(self):
for weight in self.weights:
weight.grad = None
weight.main_grad.zero_()
def step(self):
"""
Perform an optimization step for the distributed sharded model.
This method includes:
1. Gradient reduce-scatter: Synchronize gradients across all processes.
2. Master weight update: Update high-precision master weights using local gradients.
3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
4. Weight synchronization: All-gather updated weights across all processes.
Returns:
None
"""
# Step 1: Reduce-scatter the gradients
main_grad_buffer, _ = self._flatten_tensors_with_pad(
[weight.main_grad.view(-1) for weight in self.weights]
)
dist.reduce_scatter_tensor(
self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
)
self.local_main_grad_shard /= dist.get_world_size(self.dp_group)
# Step 2: Update the master weights
for weight, master_weight, (shard_start, shard_end) in zip(
self.weights, self.master_weights, self.shard_indices
):
if master_weight is None:
continue
# Extract the local gradient shard for this weight
grad = self.local_main_grad_shard[shard_start:shard_end]
# Update the master weight using gradient descent
master_weight -= grad * self.lr
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
self.master_weights,
[idx[0] for idx in self.weight_indices],
self.dp_group,
local_weights,
manual_post_all_gather_processing=self.manual_post_all_gather_processing,
)
else:
for weight, master_weight in zip(self.local_weights, self.master_weights):
if master_weight is None:
continue
                # Copy updated master weights to local weights
                weight.data.copy_(master_weight)
# Step 4: All-gather updated weights across processes
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
if self.manual_post_all_gather_processing:
quantized_weights = [
weight for weight in self.weights if isinstance(weight, QuantizedTensor)
]
post_all_gather_processing(quantized_weights)
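# The padding in _flatten_tensors_with_pad is the usual round-up-to-a-multiple
# trick; a quick hedged check of the arithmetic:
#
#     pad = (world_size - length % world_size) % world_size
#     # length=7, world_size=4 -> pad=1 (7 + 1 = 8); length=8 -> pad=0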
def _test_mini_optimizer(dp_group):
"""Make sure the implementation of MiniZero_1 and MiniFSDP is correct"""
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
weights = [
torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
]
weights_1 = weights
weights_2 = [weight.clone() for weight in weights]
weights_3 = [weight.clone() for weight in weights]
lr = 1.0
optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
optimizer_3 = MiniFSDP(weights_3, lr, dp_group)
for _ in range(100):
for w1, w2, w3 in zip(weights_1, weights_2, weights_3):
main_grads = [
torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the grads of different ranks are different.
main_grad = main_grads[rank]
w1.main_grad = main_grad
w2.main_grad = main_grad
w3.main_grad = main_grad
optimizer_1.step()
optimizer_2.step()
optimizer_3.step()
for w1, w2 in zip(weights_1, weights_2):
torch.testing.assert_close(w1, w2, atol=0, rtol=0)
for w1, w3 in zip(weights_1, weights_3):
torch.testing.assert_close(w1, w3, atol=0, rtol=0)
def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
# Create model with FP8 weights
with te.quantized_model_init(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
# Allocate main_grads for each weight
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")
optimizer_fp8 = MiniZero_1(
[w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing
)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for i in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_()
w.main_grad.zero_()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.autocast(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.autocast(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def _test_fsdp_cast_master_weights_to_fp8(
quantization, dp_group, manual_post_all_gather_processing
):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
# Configuration constants
NUM_STEPS = 100
SEED = 12345
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
linear_kwargs = {
"params_dtype": torch.bfloat16,
"bias": False,
"fuse_wgrad_accumulation": True,
}
# Create model with FP8 weights
with te.quantized_model_init(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
optimizer_fp8 = MiniFSDP(
[w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing
)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
    for _ in range(NUM_STEPS):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.autocast(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.autocast(
enabled=quantization is not None,
recipe=_get_quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def run_parallel_tests() -> None:
"""Run parallel tests"""
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
dp_group = dist.new_group(backend="nccl")
quantizations = []
if is_fp8_available():
quantizations.extend(["fp8", "fp8_cs"])
if is_fp8_block_scaling_available():
quantizations.append("fp8_block")
manual_post_all_gather_processings = [False, True]
_test_mini_optimizer(dp_group)
for quantization in quantizations:
for post_ag_processing in manual_post_all_gather_processings:
_test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing)
_test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing)
dist.destroy_process_group()
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="cast_master_weights_to_fp8 test needs at least 2 GPUs."
)
@pytest.mark.parametrize("world_size", [2])
def test_cast_master_weights_to_fp8(world_size: int) -> None:
"""Launch parallel job that runs parallel tests"""
python_exe = pathlib.Path(sys.executable).resolve()
current_file = pathlib.Path(__file__).resolve()
command = [
python_exe,
"-m",
"torch.distributed.run",
f"--nproc_per_node={world_size}",
current_file,
"--parallel",
]
result = subprocess.run(
command,
check=True,
)
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
    args = parser.parse_args()
    if args.parallel:
        run_parallel_tests()


if __name__ == "__main__":
    main()
...@@ -34,6 +34,7 @@ from transformer_engine.pytorch import (
    Float8Tensor,
)

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
...
...@@ -14,7 +14,7 @@ import transformer_engine.pytorch as te
Distributed numerics tests

This numerical test aims for zero tolerance, for absolute confidence in numerics.
In the case of NVFP4, with the custom NVFP4 quantization, we matched bitwise
results with the native silicon. For distributed test cases, we can do the same
thing by comparing BF16 AG results with the low-precision AG results at the layer level.
"""
...
...@@ -12,22 +12,26 @@ import transformer_engine.pytorch as te
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

NUM_PROCS: int = torch.cuda.device_count()


def _run_test(fp_init, sharding_dims, recipe, layer_type):
    test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
    test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]

    if fp_init:
        test_cmd += ["--fp8-init"]
    if len(sharding_dims) == 1:
        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
    elif len(sharding_dims) == 2:
        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
    else:
        assert False
    test_cmd += ["--recipe", recipe]
    test_cmd += ["--layer-type", layer_type]
    result = subprocess.run(test_cmd, env=os.environ, check=True)
...@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
    # Skip invalid configurations
    if torch.cuda.device_count() < 4:
        pytest.skip("FSDP2 test requires at least 4 GPUs")

    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    elif not fp8_available:
        pytest.skip(reason_for_no_fp8)

    _run_test(fp8_init, sharding_dims, recipe, layer_type)


def test_dummy() -> None:
...
...@@ -8,8 +8,8 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
...@@ -6,8 +6,8 @@ import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
...@@ -7,10 +7,10 @@ import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
...@@ -12,10 +12,10 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling

import pytest
import torch
...
...@@ -2,27 +2,41 @@
#
# See LICENSE for license information.

import random
import contextlib
import os
from typing import Optional, List

import pytest
import torch

from transformer_engine.pytorch.cpu_offload import (
    get_cpu_offload_context,
    OffloadableLayerState,
    DefaultOffloadSynchronizer,
    start_offload,
    mark_not_offload,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from utils import ModelConfig
import transformer_engine_torch as tex

# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available()

quantization_recipes: List[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if fp8_block_scaling_available:
    quantization_recipes.append(recipe.Float8BlockScaling())
if mxfp8_available:
    quantization_recipes.append(recipe.MXFP8BlockScaling())
if nvfp4_available:
    quantization_recipes.append(recipe.NVFP4BlockScaling())
model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
...@@ -32,181 +46,709 @@ NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Disable garbage collection to detect reference cycles in these tests.
# We do not want reference cycles, because they can result in CUDA out-of-memory errors.
import gc

gc.disable()


class Utils:
    tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
    _B = 64
    _S = 256
    _H = 4
    _D = 256

    @staticmethod
    def long_job(stream: Optional[torch.cuda.Stream] = None):
        NUM_ITERS = 6000
        if stream is None:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            for i in range(NUM_ITERS):
                Utils.tensor1.normal_()

    @staticmethod
    def measure_time(func):
        import time

        torch.cuda.synchronize()
        start = time.time()
        func()
        torch.cuda.synchronize()
        end = time.time()
        return (end - start) * 1000
@staticmethod
def get_cuda_memory_mb():
return torch.cuda.memory_allocated() / (1024**2)
@staticmethod
def get_max_cuda_memory_mb():
return torch.cuda.max_memory_allocated() / (1024**2)
@staticmethod
def get_cpu_memory_mb() -> float:
import psutil, os
return psutil.Process(os.getpid()).memory_info().rss / (1024**2)
@staticmethod
def get_layer_names():
return [
"linear",
"layernorm_linear",
"layernorm_mlp",
"grouped_linear",
"multihead_attention",
"transformer_layer",
"linear_op",
"layernorm_mlp_ops",
]
@staticmethod
def create_layer(layer_type: str):
if layer_type == "linear":
return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_linear":
return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp":
return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "multihead_attention":
return te.MultiheadAttention(
Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16
)
elif layer_type == "grouped_linear":
return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16)
elif layer_type == "transformer_layer":
return te.TransformerLayer(
Utils._D,
Utils._D,
Utils._H,
attention_dropout=0.0,
hidden_dropout=0.0,
params_dtype=torch.bfloat16,
)
elif layer_type == "linear_op":
return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16)
elif layer_type == "layernorm_mlp_ops":
return te.ops.Sequential(
te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16),
te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
te.ops.GELU(), te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
), )
} else:
raise ValueError(f"Unknown layer type: {layer_type}")
@staticmethod
def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor:
shape = (Utils._B, Utils._S, Utils._D)
tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
if recipe is None:
tensor = tensor.requires_grad_() if requires_grad else tensor
return tensor
elif recipe.delayed():
quantizer = te.tensor.float8_tensor.Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
scale=torch.tensor([1.0], device="cuda"),
amax=torch.tensor([1.0], device="cuda"),
)
return quantizer(tensor)
elif recipe.float8_current_scaling():
quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"
)
return quantizer(tensor)
elif recipe.float8_block_scaling():
quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True
)
return quantizer(tensor)
elif recipe.mxfp8():
quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
return quantizer(tensor)
elif recipe.nvfp4():
quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer()
return quantizer(tensor)
@staticmethod
def create_recipe_ctx(recipe: Optional[recipe.Recipe]):
if recipe is None:
return lambda: contextlib.nullcontext()
else:
return lambda: te.fp8_autocast(fp8_recipe=recipe)
@staticmethod
def get_tensor_size_mb(tensor):
if tensor is None:
return 0
if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors())
else:
return tensor.numel() * tensor.element_size() / (1024**2)
@staticmethod
def memory_leak_check():
# Should be called before each test.
# Only cublas workspaces and some global tensors are allowed to be allocated.
# All other allocations should be released.
# This is a simple check to catch memory leaks.
if Utils.get_cuda_memory_mb() > 1000:
memory_num = Utils.get_cuda_memory_mb()
import gc
gc.collect() # We want next test to be run with clean state.
gc.disable()
raise RuntimeError(f"Memory leak: {memory_num} MB")
class TestsOffloadableLayerState:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
for each layer offload random number of random tensors.
Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
"""
Utils.memory_leak_check()
NUM_ITERATIONS = 10
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1
for _ in range(NUM_TENSORS):
tensor = Utils.create_tensor(recipe)
original_tensors.append(tensor)
tensor_id = offload_layer_state.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensors_ids.append(tensor_id)
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
offload_layer_state.start_reload()
for j in range(len(tensors_ids)):
tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[j].shape
assert tensor_gpu.dtype == original_tensors[j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[j])
offload_layer_state.release_all_memory()
torch.cuda.synchronize()
def test_offload_base_tensor(self):
Utils.memory_leak_check()
stream = torch.cuda.Stream()
offload_layer_state = OffloadableLayerState(
offload_stream=stream,
)
init_cuda_memory = Utils.get_cuda_memory_mb()
x = Utils.create_tensor(None)
x_size = Utils.get_tensor_size_mb(x)
x_1 = x[::2]
x_2 = x[1::2]
start_offload(x_1, offload_base_tensor=True)
start_offload(x_2, offload_base_tensor=True)
x1_id = offload_layer_state.push_tensor(x_1)
x2_id = offload_layer_state.push_tensor(x_2)
del x_1, x_2
offload_layer_state.start_offload()
offload_layer_state.release_activation_forward_gpu_memory()
assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1)
offload_layer_state.start_reload()
x_1 = offload_layer_state.pop_tensor(x1_id)
x_2 = offload_layer_state.pop_tensor(x2_id)
assert x_1.device.type == "cuda"
assert x_2.device.type == "cuda"
assert torch.allclose(x_1, x[::2])
assert torch.allclose(x_2, x[1::2])
del x
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1)
class TestsDefaultOffloadSynchronizer:
@pytest.mark.parametrize("random_num_tensors", [True, False])
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_general(self, random_num_tensors, recipe):
"""
        Test general functionality of DefaultOffloadSynchronizer: offload NUM_LAYERS - 1
        out of NUM_LAYERS layers, offloading a random number of random tensors for each
        layer. Then run the backward pass for each layer and check that the reloaded
        tensors are equal to the original tensors.
"""
Utils.memory_leak_check()
NUM_LAYERS = 10
NUM_ITERATIONS = 10
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
for _ in range(NUM_ITERATIONS):
original_tensors = []
tensors_ids = []
layer_ids = []
for i in range(NUM_LAYERS):
NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1
layer_tensors = []
layer_tensors_ids = []
layer_id = offload_synchronizer.fwd_step()
for _ in range(NUM_LAYER_TENSORS):
tensor = Utils.create_tensor(recipe)
layer_tensors.append(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
layer_tensors_ids.append(tensor_id)
layer_ids.append(layer_id)
tensors_ids.append(layer_tensors_ids)
original_tensors.append(layer_tensors)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(layer_ids[i])
for j in range(len(tensors_ids[i])):
tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j])
assert tensor_gpu.device.type == "cuda"
assert tensor_gpu.shape == original_tensors[i][j].shape
assert tensor_gpu.dtype == original_tensors[i][j].dtype
torch.testing.assert_close(tensor_gpu, original_tensors[i][j])
offload_synchronizer.finish_part_of_bwd()
torch.cuda.synchronize()
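# The synchronizer protocol used above, in miniature (a sketch of the call
# sequence only, inferred from this test; not a definitive API reference):
#
#   sync = DefaultOffloadSynchronizer(num_layers=2, num_offloaded_layers=1)
#   layer_id = sync.fwd_step()           # announce a layer's forward pass
#   tensor_id = sync.push_tensor(act)    # hand over an activation
#   ...
#   sync.bwd_step(layer_id)              # announce the layer's backward pass
#   act = sync.pop_tensor(tensor_id)     # activation is back on the GPU
#   sync.finish_part_of_bwd()            # this part of the backward is done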
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, recipe):
torch.cuda.synchronize()
Utils.memory_leak_check()
NUM_LAYERS = 10
torch.cuda.reset_peak_memory_stats()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=NUM_LAYERS,
num_offloaded_layers=NUM_LAYERS - 1,
)
init_cuda_memory = Utils.get_cuda_memory_mb()
tensor_ids = []
torch.cuda.synchronize()
for _ in range(NUM_LAYERS):
offload_synchronizer.fwd_step()
tensor = Utils.create_tensor(recipe)
tensor_size = Utils.get_tensor_size_mb(tensor)
tensor_id = offload_synchronizer.push_tensor(tensor)
assert tensor.device.type == "cuda"
tensor_ids.append(tensor_id)
del tensor, tensor_id
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1)
for i in range(NUM_LAYERS - 1, -1, -1):
offload_synchronizer.bwd_step(i)
tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i])
assert tensor_gpu.device.type == "cuda"
del tensor_gpu, tensor_ids[i]
offload_synchronizer.finish_part_of_bwd()
del tensor_ids
torch.cuda.synchronize()
if recipe is None:
assert Utils.get_max_cuda_memory_mb() == pytest.approx(
init_cuda_memory + tensor_size, 0.1
)
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_multiple_tensor_offload(self, recipe):
Utils.memory_leak_check()
init_cpu_memory = Utils.get_cpu_memory_mb()
init_cuda_memory = Utils.get_cuda_memory_mb()
offload_synchronizer = DefaultOffloadSynchronizer(
num_layers=2,
num_offloaded_layers=1,
)
x1 = Utils.create_tensor(recipe)
x_size = Utils.get_tensor_size_mb(x1)
offload_synchronizer.fwd_step()
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.fwd_step()
# Only one copy of tensor on cpu is allocated.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
del x1
offload_synchronizer.bwd_step(1)
offload_synchronizer.bwd_step(0)
offload_synchronizer.finish_part_of_bwd()
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
class TestTELayers:
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_sanity(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
recipe_ctx = Utils.create_recipe_ctx(recipe)
init_cuda_memory = Utils.get_cuda_memory_mb()
OFFLOAD_LAYERS = 6
NUM_LAYERS = 10
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=OFFLOAD_LAYERS,
model_layers=NUM_LAYERS,
)
layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)]
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
out = inp
for i in range(NUM_LAYERS):
with offload_ctx, recipe_ctx():
# Ops-based layers don't support is_first_microbatch parameter
if layer_type in ["linear_op", "layernorm_mlp_ops"]:
out = layers[i](out, **m_splits)
else:
out = layers[i](out, is_first_microbatch=False, **m_splits)
out = sync_function(out)
out.sum().backward()
torch.cuda.synchronize()
del out, inp, layers
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_memory(self, layer_type, recipe):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=1,
model_layers=2,
offload_activations=True,
offload_weights=False,
)
recipe_ctx = Utils.create_recipe_ctx(recipe)
layer = Utils.create_layer(layer_type)
inp = Utils.create_tensor(None)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
# Ops-based layers don't support is_first_microbatch parameter
is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
with recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=True, **m_splits)
out.sum().backward()
del inp
init_cuda_memory = Utils.get_cuda_memory_mb()
# run layer without offload
inp = Utils.create_tensor(None)
with recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
with recipe_ctx():
out = out + 1
del inp
cuda_memory_no_offload = Utils.get_cuda_memory_mb()
out.sum().backward()
# run layer with offload
inp = Utils.create_tensor(None)
with offload_ctx, recipe_ctx():
if is_ops_layer:
out = layer(inp, **m_splits)
else:
out = layer(inp, is_first_microbatch=False, **m_splits)
out = sync_function(out)
with offload_ctx, recipe_ctx():
out = out + 1
out = sync_function(out)
del inp
assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb()
# This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer.
# It helps catch cases where an offloaded tensor still has a live pointer, which would
# cause an unnecessary copy to the CPU and prevent GPU memory from being released.
assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx(
cuda_memory_no_offload, 0.1
)
out.sum().backward()
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("recipe", quantization_recipes)
def test_manual_synchronization(self, recipe, layer_type):
Utils.memory_leak_check()
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
enabled=True,
model_layers=6,
offload_activations=True,
manual_synchronization=True,
)
layer_1 = Utils.create_layer(layer_type)
layer_2 = Utils.create_layer(layer_type)
inp1 = Utils.create_tensor(None)
inp2 = Utils.create_tensor(None)
recipe_ctx = Utils.create_recipe_ctx(recipe)
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
init_cuda_memory = Utils.get_cuda_memory_mb()
# 1 fwd
with offload_ctx, recipe_ctx():
out_1 = layer_1(inp1, **m_splits)
out_1 = sync_function(out_1)
with offload_ctx, recipe_ctx():
out_2 = layer_2(inp2, **m_splits)
out_2 = sync_function(out_2)
mark_not_offload(out_1, out_2)
del inp1, inp2
memory_before_offload = Utils.get_cuda_memory_mb()
manual_controller.start_offload_layer(0)
manual_controller.release_activation_forward_gpu_memory(0)
manual_controller.start_offload_layer(1)
manual_controller.release_activation_forward_gpu_memory(1)
memory_after_offload = Utils.get_cuda_memory_mb()
assert memory_after_offload + EPSILON < memory_before_offload
manual_controller.start_reload_layer(0)
manual_controller.start_reload_layer(1)
memory_after_reload = Utils.get_cuda_memory_mb()
assert memory_after_reload == pytest.approx(memory_before_offload, 0.1)
out_1.sum().backward()
out_2.sum().backward()
@pytest.mark.parametrize("recipe", quantization_recipes)
@pytest.mark.parametrize("layer_type", Utils.get_layer_names())
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
@pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False])
@pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"])
def test_numerics(
self,
recipe,
layer_type,
use_cuda_graphs,
backend,
retain_pinned_cpu_buffers,
):
# Skip ops-based layers with Float8BlockScaling recipe
if (
layer_type in ["linear_op", "layernorm_mlp_ops"]
and recipe is not None
and recipe.float8_block_scaling()
):
pytest.skip("Fusible operations do not support FP8 block scaling recipe")
recipe_ctx = Utils.create_recipe_ctx(recipe)
if use_cuda_graphs and not retain_pinned_cpu_buffers:
pytest.skip(
"Cuda graphs are not yet supported with cpu offloading when"
" retain_pinned_cpu_buffers is False."
)
if backend == "FusedAttention" and use_cuda_graphs:
pytest.skip(
"Fused attention + cuda graphs is temporarily broken, not because of cpu offloading"
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
offload_ctx, sync_function = get_cpu_offload_context(
enabled=True,
num_layers=1,
model_layers=2,
offload_activations=True,
offload_weights=False,
retain_pinned_cpu_buffers=retain_pinned_cpu_buffers,
)
class Callable(torch.nn.Module):
def __init__(self, offload_ctx=None, sync_function=None):
super().__init__()
self.layers = torch.nn.ModuleList(
[Utils.create_layer(layer_type) for _ in range(2)]
)
self.offload_ctx = offload_ctx
self.sync_function = sync_function
def forward(self, x):
m_splits = (
{"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
if layer_type == "grouped_linear"
else {}
)
is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
for layer in self.layers:
with self.offload_ctx, recipe_ctx():
if is_ops_layer:
x = layer(x, **m_splits)
else:
x = layer(x, is_first_microbatch=False, **m_splits)
if self.sync_function is not None:
x = self.sync_function(x)
return x
callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function)
callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None)
# copy parameters
for param_offload, param_no_offload in zip(
callable_offload.parameters(), callable_no_offload.parameters()
):
param_offload.data.copy_(param_no_offload.data)
@pytest.mark.parametrize("quantization_recipe", quantization_recipes) x = Utils.create_tensor(None)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model if use_cuda_graphs:
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] callable_offload = te.make_graphed_callables(
if model_name in ["multihead_attention", "transformer_layer"]: callable_offload,
available_backends, *_ = get_available_attention_backends( (x,),
model_config["small"], enabled=recipe is not None,
qkv_dtype=torch.bfloat16, recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None),
qkv_layout="sbhd_sbhd_sbhd",
) )
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup # warm up (for example to compute sf for delayed scaling)
_warmup_model(modules_list, quantization_recipe) for _ in range(4):
out = callable_offload(x)
out.sum().backward()
out = callable_no_offload(x)
out.sum().backward()
callable_offload.zero_grad(set_to_none=True)
out_offload = callable_offload(x)
out_offload.sum().backward()
# save output and parameter values
offload_outs = [out_offload]
for param in callable_offload.parameters():
offload_outs.append(param.detach().clone())
torch.cuda.reset_peak_memory_stats()
out_no_offload = callable_no_offload(x)
out_no_offload.sum().backward()
# collect output and parameter values
no_offload_outs = [out_no_offload]
for param in callable_no_offload.parameters():
no_offload_outs.append(param.detach().clone())
# check if tensors are the same
for i in range(len(offload_outs)):
assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}."
torch.cuda.synchronize()
def test_example_from_doc(self):
offload_stream = torch.cuda.Stream()
num_layers = 10
layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)]
inp = [Utils.create_tensor(None) for _ in range(num_layers)]
out = [None] * num_layers
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True,
model_layers=num_layers,
manual_synchronization=True,
offload_stream=offload_stream,
)
for i in range(num_layers):
with cpu_offload_context:
out[i] = layers[i].forward(inp[i])
out[i] = sync_function(out[i])
manual_controller.start_offload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
manual_controller.release_activation_forward_gpu_memory(i)
for i in range(num_layers - 1, -1, -1):
# these calls are intended to be done in the backward pass
manual_controller.start_reload_layer(i)
offload_stream.synchronize()
for i in range(num_layers):
out[i].sum().backward()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, List, Optional
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: List[Optional[recipe.Recipe]] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if mxfp8_available:
quantization_recipes.append(recipe.MXFP8BlockScaling())
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
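# For example, a (hypothetical) invocation satisfying both assertions; the
# test-file path is illustrative and depends on your checkout layout:
#
#   NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 pytest tests/pytorch/test_cpu_offload.py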
# Activation offloading in attention is supported only for the fused and flash
# attention backends, so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
"linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
"multihead_attention": lambda: te.MultiheadAttention(
SIZE, NUM_HEADS, params_dtype=torch.bfloat16
),
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
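# Worked example of the estimate above (with SIZE = 512 as configured here):
# a te.Linear(SIZE, SIZE) weight has 512 * 512 = 262144 elements. Under MXFP8
# with the extra column-wise copy, the cached size is
# 2 * 262144 * (1 + 1/32) / 1024**2 = 0.515625 MiB; under FP8 tensor scaling
# with a cached transpose it is 2 * 262144 / 1024**2 = 0.5 MiB.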
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
"""
# Reset memory
gc.collect()
torch.cuda.empty_cache()
# Context and sync function for CPU offloading
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(modules),
model_layers=len(modules) + 1,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = contextlib.nullcontext()
sync_function = lambda x: x
# Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
# Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
if model_name in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup
_warmup_model(modules_list, quantization_recipe)
# Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# Check for expected memory usage
assert memory_with_offload < memory_without_offload
memory_from_cached_weights = _estimate_cached_weight_size(
model_name,
modules_list,
quantization_recipe,
)
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from typing import Iterable, List, Union from typing import Callable, Dict, Iterable, List, Tuple, Union
import pytest import pytest
import torch import torch
...@@ -173,6 +173,20 @@ def get_outputs( ...@@ -173,6 +173,20 @@ def get_outputs(
return values return values
def reset_graphs(
graphed_callables: Union[Callable, List[Callable], Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
) -> None:
"""Reset CUDA graphs."""
if isinstance(graphed_callables, (tuple, list)):
for graphed_callable in graphed_callables:
graphed_callable.reset()
elif isinstance(graphed_callables, dict):
for graphed_callable in graphed_callables.values():
graphed_callable.reset()
else:
graphed_callables.reset()
class _Sequential(torch.nn.Sequential): class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules""" """Sequential model that forwards keyword arguments to modules"""
...@@ -335,7 +349,12 @@ def _test_cuda_graphs( ...@@ -335,7 +349,12 @@ def _test_cuda_graphs(
output.backward(grad_output) output.backward(grad_output)
optimizer.step() optimizer.step()
return get_outputs(model, output) outputs = get_outputs(model, output)
if graph_mode == "full":
reset_graphs(model)
elif graph_mode == "individual":
reset_graphs(modules)
return outputs
@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
...@@ -487,7 +506,10 @@ def _test_cuda_graphs_with_dot_product_attention( ...@@ -487,7 +506,10 @@ def _test_cuda_graphs_with_dot_product_attention(
output = model(*inputs) output = model(*inputs)
output.backward(grad_output) output.backward(grad_output)
return get_outputs(model, output) outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dtype", dtypes)
...@@ -572,7 +594,10 @@ def _test_cuda_graphs_with_kwargs( ...@@ -572,7 +594,10 @@ def _test_cuda_graphs_with_kwargs(
output.backward(grad_output) output.backward(grad_output)
optimizer.step() optimizer.step()
return get_outputs(model, output) outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
def test_make_graphed_callables_with_kwargs( def test_make_graphed_callables_with_kwargs(
...@@ -687,7 +712,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( ...@@ -687,7 +712,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
optimizer.step() optimizer.step()
outputs = [y for _, y in sorted(outputs.items())] outputs = [y for _, y in sorted(outputs.items())]
return get_outputs(model, outputs) outputs = get_outputs(model, outputs)
if with_graph:
reset_graphs(layer_forwards)
return outputs
def test_make_graphed_callables_with_interleaved_pipeline_parallelism( def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
......