Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
@contextmanager
def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
else:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
yield
finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
......
@@ -4,6 +4,8 @@
import os
import subprocess
import sys
import pathlib
import pytest
import torch
@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA
}
@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK not surpport fused attn for now, CP tests require sm80+.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
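# qkv_layout above is e.g. "bshd_bshd_bshd" when qkv_format == "bshd" (the same layout for Q, K and V)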
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......
@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
import sys
import pathlib
import logging
import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
ModelConfig,
reset_rng_states,
get_available_attention_backends,
)
# Reset RNG states
reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
# test: b, sq, hq, dqk,
"infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
"infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
}
qkv_formats = ["bshd", "sbhd", "thd"]
@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
from test_log import LOG_QUANTIZED_CONFIG
kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
kwargs["config_file"].flush()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
set_weight_tensor_tp_group_reduce(gather_weight)
if WORLD_SIZE % 2 != 0:
return # skip
TP_SIZE = WORLD_SIZE // 2
DP_SIZE = 2
TP_RANK = WORLD_RANK % TP_SIZE
DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
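# Worked example (assuming WORLD_SIZE == 4): TP_SIZE == 2, so ranks 0 and 1 map to
# (TP_RANK, DP_RANK) == (0, 0) and (1, 0), while ranks 2 and 3 map to (0, 1) and (1, 1);
# the TP groups built below are then [0, 1] and [2, 3].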
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
parallel_mode,
weight_seed=TP_RANK * 1234,
data_seed=DP_RANK * 1234,
tp_size=TP_SIZE,
tp_rank=TP_RANK,
)
tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
tp_group = dist.new_group(ranks=tp_group_ranks)
model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
_run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
"""
......
@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
# FP8 enabled - True by default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
# modify_tensor_enabled - False by default
# modify_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
# inspect_tensor_enabled - False by default
# inspect_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
finally:
debug_api.end_debug()
@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
)[0]
# check modify_tensor
@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
)[0]
finally:
debug_api.end_debug()
@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
# modify_tensor
debug_api.transformer_engine.modify_tensor(
@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
tensor_fp8 = quantizer(tensor)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
tensor_name="activation",
iteration=200,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
)
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
debug_api.transformer_engine.inspect_tensor_postquantize(
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
)[0]
assert_empty()
finally:
@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
def feed(tensor, tensor_fp8, quantizer):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
def log_stats():
@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return STATS_BUFFERS.log_stats()
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
def fp8_tensor(t):
return quantizer(t.cuda())
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
feed(tensors[0], tensors_fp8[0], quantizer)
feed(tensors[1], tensors_fp8[1], quantizer)
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
feed(tensor2, fp8tensor2, quantizer)
stats2 = log_stats()
assert len(stats1.keys()) > 0
......
test:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 1
freq: 3
LogFp8TensorStats:
enabled: True
tensors: activation
stats: [underflows%]
start_step: 1
freq: 5
\ No newline at end of file
test:
enabled: True
layers:
layer_name_regex_pattern: .*1
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 0
freq: 100000
\ No newline at end of file
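The two YAML snippets above are debug-tool configs of the kind the tests in this commit pass to nvdlfw_inspect. A minimal initialization sketch, assuming the first snippet is saved as log_config.yaml and that feature_dirs points at Transformer Engine's debug feature directory (both paths are illustrative only):
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
    config_file="log_config.yaml",  # illustrative path to the YAML above
    feature_dirs=["/path/to/transformer_engine/debug/features"],  # assumed location
    log_dir="./debug_logs",
)
# ... run forward/backward, calling debug_api.step() once per training iteration ...
debug_api.end_debug()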
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
recipes = [
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_block_scaling",
"mxfp8",
]
bare_stats = [
"underflows%",
"scale_inv_min",
"scale_inv_max",
"mse",
]
all_stats = []
for r in recipes:
for stat in bare_stats:
for columnwise_postfix in ["", "_columnwise"]:
if (
r in ["fp8_current_scaling", "fp8_block_scaling"]
and torch.cuda.get_device_capability()[0] < 9
):
# hopper is needed for current-scaling, block-scaling
continue
if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
# blackwell is needed for mxfp8
continue
if (
r in ["fp8_delayed_scaling", "fp8_current_scaling"]
and columnwise_postfix == "_columnwise"
):
# columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
continue
all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
"""
Helper context manager that
1. writes the YAML `config_str` to a temporary file,
2. starts a debug session, and
3. yields the directory that contains the statistics log.
The session is closed automatically – even on exceptions – so every test
stays concise and leak-free.
"""
with tempfile.NamedTemporaryFile(
mode="w", delete=False
) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
cfg_file.write(config_str)
cfg_file.flush()
debug_api.initialize(
config_file=cfg_file.name,
feature_dirs=feature_dirs,
log_dir=log_dir,
)
try:
yield log_dir
finally:
debug_api.end_debug()
def read_log(log_dir: str) -> str:
"""Return the content of the statistics log produced by `debug_session`."""
stat_path = os.path.join(
log_dir,
"nvdlfw_inspect_statistics_logs",
"nvdlfw_inspect_globalrank-0.log",
)
with open(stat_path, "r") as f:
return f.read()
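A minimal usage sketch of the two helpers above (test_sanity below exercises the same flow end to end); the stat choice and layer size are illustrative only:
def _example_usage(feature_dirs):
    # Sketch only: run one FP8 step of a small Linear layer under a debug
    # session and return the raw statistics log text.
    cfg = LOG_QUANTIZED_CONFIG_BASE.format(stats="underflows%")
    with debug_session(cfg, feature_dirs) as log_dir:
        layer = te.Linear(128, 128, params_dtype=torch.bfloat16)
        inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
            layer(inp).sum().backward()
        debug_api.step()
        return read_log(log_dir)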
def test_sanity(feature_dirs):
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
assert output, "Output is empty"
for stat in all_stats:
assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
pytest.skip(reason_for_no_mxfp8)
if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
pytest.skip(reason_for_no_fp8_block_scaling)
log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
recipe_state = RecipeState.create(
fp8_recipe,
mode="forward",
num_quantizers=3,
)
tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
for line in output.splitlines():
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
)
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
(abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
)
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
# If a layer does not invoke any feature in the current iteration,
# it switches into non-debug mode.
# This test checks that this works correctly:
# non-quantized statistics should be logged every 3 iterations,
# and quantized statistics every 5 iterations.
with tempfile.TemporaryDirectory() as temp_dir:
debug_api.initialize(
config_file=configs_dir + "/log_config.yaml",
feature_dirs=feature_dirs,
log_dir=temp_dir,
)
if layer == "linear":
model = te.Linear(128, 128, name="linear1")
elif layer == "transformer":
model = te.TransformerLayer(128, 128, 4, name="transformer1")
else:
raise ValueError(f"Invalid layer: {layer}")
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
with open(
os.path.join(
temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
),
"r",
) as f:
file_content = f.read()
for i in range(1, 20):
if i % 3 == 0 or i % 5 == 0:
assert f"iteration={i:06d}" in file_content
else:
assert f"iteration={i:06d}" not in file_content
debug_api.end_debug()
TEDebugState._reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
debug_api.end_debug()
TEDebugState._reset()
if debug_tools_initialized:
# This config logs stats starting from iteration 0, every N iterations for a huge N >> NUM_ITERS.
# So after 1 warm-up iteration, these layers should run in non-debug mode.
debug_api.initialize(
config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
)
try:
if layer == "linear":
model = torch.nn.Sequential(
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
).cuda()
NUM_ITERS = 18000
elif layer == "transformer":
model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda()
NUM_ITERS = 2000
x = torch.randn(1, 1, 1).cuda()
y = model(x)
y.sum().backward()
debug_api.step()
torch.cuda.synchronize()
time_start = time.time()
for i in range(NUM_ITERS):
y = model(x)
y.sum().backward()
if debug_tools_initialized:
debug_api.step()
torch.cuda.synchronize()
time_end = time.time()
finally:
if debug_tools_initialized:
debug_api.end_debug()
return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
# Runs one layer many times on a very small tensor -
# GPU time should be negligible, so total time is dominated by CPU time.
# If a layer does not invoke any feature in the current iteration,
# it switches into non-debug mode and should not add any non-negligible CPU overhead
# compared to a layer run without debug tools initialized.
with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
print(f"with_debug_tools: {with_debug_tools} s")
print(f"without_debug_tools: {without_debug_tools} s")
assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin
@@ -519,6 +519,7 @@ def _train(opts):
if opts.use_cuda_graphs:
del test_graph
torch.cuda.synchronize()
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig
model_configs = {
"small": ModelConfig(2, 10, 2, 16),
}
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
num_devices = torch.cuda.device_count()
assert num_devices > 1, "This test requires more than one GPU!"
tensor_device = num_devices - 1
dtype = torch.bfloat16
config = model_configs[model]
args = []
kwargs = {}
bwd_args = []
if module == "TransformerLayer":
model = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
params_dtype=dtype,
attn_input_format="thd",
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
(num_tokens, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=tensor_device,
requires_grad=True,
)
for _ in range(3)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
elif module == "Linear":
model = Linear(
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
if module == "DotProductAttention":
out.backward(*bwd_args)
else:
loss = out.sum()
loss.backward()
current_device_after = torch.cuda.current_device()
tensor_device_out = out.get_device()
tensor_device_grad = args[0].grad.get_device()
assert (
current_device_after == current_device_before
), "The current device should not have changed!"
assert (
tensor_device_out == tensor_device
), "The output tensor should be the same as the input tensors!"
assert (
tensor_device_grad == tensor_device
), "The gradient tensor should be the same as the input tensors!"
@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=True,
all_gather_usage=(block_scaling_dim == 1),
)
self._test_quantize_dequantize(
quantizer=quantizer,
......