Commit 87e3e56e authored by yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
 @contextmanager
 def use_jax_gemm(enabled=False):
-    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
+    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
     try:
         if enabled:
-            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
+        else:
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
         yield
     finally:
         if enabled:
             if orig_custom_calls_filter is None:
-                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+                os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
             else:
-                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
+                os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
     get_cu_seqlens_on_cp_rank,
 )
 import transformer_engine_torch as tex
-from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
+from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from transformer_engine.common.recipe import DelayedScaling
...
@@ -4,6 +4,8 @@
 import os
 import subprocess
+import sys
+import pathlib
 import pytest
 import torch
@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
     get_cudnn_version,
 )
 from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
-from test_fused_attn import ModelConfig
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import ModelConfig, get_available_attention_backends
+
+# Initialize RNG state
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # MHA
-    "cp_1_3": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-    ),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # GQA
-    "cp_2_3": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
-    ),  # GQA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
 }
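The rewritten configs rely on the shared utils.ModelConfig, whose leading positional arguments change from (b, h, hg, d, sq, skv, p, mask, bias) to (batch, seqlen, heads, head_dim), with everything else keyword-only. A rough sketch of the assumed signature, inferred from the call sites in this diff (defaults are illustrative, not taken from the source):

# Assumed shape of the new helper; the real definition lives in tests/pytorch/utils.py.
class ModelConfig:
    def __init__(
        self,
        batch_size,               # b
        max_seqlen_q,             # sq (max_seqlen_kv presumably defaults to sq)
        num_heads,                # hq
        head_dim_qk,              # dqk
        num_gqa_groups=None,      # defaults to num_heads, i.e. MHA
        attn_mask_type="no_mask",
        attn_bias_type="no_bias",
        window_size=(-1, -1),
        head_dim_v=None,          # defaults to head_dim_qk; 64 in the MLA cases
        total_requests=None,      # KV-cache / inference tests
        max_ctx_len=None,         # KV-cache / inference tests
    ):
        ...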
@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
         "--nproc-per-node=" + str(num_gpus_per_node),
     ]
     te_path = os.getenv("TE_PATH", "/opt/transformerengine")
-    script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
+    script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
     args.append(script_path)
     for k, v in kwargs.items():
         args.append(f"{k}={v}")
@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 model_configs_fused_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
-    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
-    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
-    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
-    "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
-    "cp_2_4": ModelConfig(
-        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
-    ),  # GQA
-    "cp_3_0": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_1": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_2": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
-    ),  # MLA
-    "cp_3_3": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
-    ),  # MLA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
+    ),  # MHA
+    "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
+    "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
+    "cp_2_2": ModelConfig(
+        2,
+        4096,
+        12,
+        128,
+        num_gqa_groups=2,
+        attn_mask_type="causal",
+        attn_bias_type="post_scale_bias",
+    ),  # GQA
+    "cp_2_3": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
+    ),  # GQA
+    "cp_2_4": ModelConfig(
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+    ),  # GQA
+    "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
+    "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
 }
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
-@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK not surpport fused attn for now, CP tests require sm80+.")
+@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("MLA CP currently only support KV P2P!")
     if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently does not support FP8 attention!")
+    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtypes[dtype],
+        qkv_layout="_".join([qkv_format] * 3),
+        window_size=config.window_size,
+        context_parallel=True,
+    )
+    _, fused_attn_supported, _ = available_backends
+    if not fused_attn_supported:
+        pytest.skip("No attention backend available.")
     subprocess.run(
         get_bash_arguments(
...
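For context on the unpacking above: available_backends is assumed to be a (flash, fused, unfused) support triple, which is why the middle element is read as fused_attn_supported. The layout string passed to the probe is simply the format repeated for Q, K and V:

# Worked example of the qkv_layout argument built above.
qkv_format = "sbhd"
qkv_layout = "_".join([qkv_format] * 3)
assert qkv_layout == "sbhd_sbhd_sbhd"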
@@ -5,18 +5,14 @@
 from collections import OrderedDict
 from typing import List
 import os
+import sys
+import pathlib
 import logging
 import math
 import pytest
 import torch
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)
 from torch.distributions import Exponential
 from transformer_engine.pytorch import make_graphed_callables
 from transformer_engine.common import recipe
@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
     is_bf16_compatible,
 )
-# Initialize RNG state
-seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+    ModelConfig,
+    reset_rng_states,
+    get_available_attention_backends,
+)
+
+# Reset RNG states
+reset_rng_states()
 param_types = [torch.float16]
 if is_bf16_compatible():
     param_types.append(torch.bfloat16)
 model_configs_infer = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "infer_0": ModelConfig(
-        4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
-    ),
-    "infer_1": ModelConfig(
-        2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
-    ),
+    # test: b, sq, hq, dqk,
+    "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
+    "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
 }
 qkv_formats = ["bshd", "sbhd", "thd"]
@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
     qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
     if is_paged:
         qkv_layout = "paged_kv_" + qkv_layout
-    available_backends, _, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
...
@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
     set_weight_tensor_tp_group_reduce(True)  # reset
+@run_debug_test
+def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
+    from test_log import LOG_QUANTIZED_CONFIG
+    kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+    set_weight_tensor_tp_group_reduce(gather_weight)
+    if WORLD_SIZE % 2 != 0:
+        return  # skip
+    TP_SIZE = WORLD_SIZE // 2
+    DP_SIZE = 2
+    TP_RANK = WORLD_RANK % TP_SIZE
+    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
+    debug_api.set_tensor_reduction_group(NCCL_WORLD)
+    x, weight = _get_tensors(
+        parallel_mode,
+        weight_seed=TP_RANK * 1234,
+        data_seed=DP_RANK * 1234,
+        tp_size=TP_SIZE,
+        tp_rank=TP_RANK,
+    )
+    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
+    tp_group = dist.new_group(ranks=tp_group_ranks)
+    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
+    _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
+    set_weight_tensor_tp_group_reduce(True)  # reset
 @run_debug_test
 def test_log_expert_parallel(**kwargs):
     """
...
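The rank arithmetic in sanity_test_log_quantized_stats splits the world into DP_SIZE=2 data-parallel replicas of TP_SIZE contiguous tensor-parallel ranks. A worked example for WORLD_SIZE=4 (a sketch only; the real values come from the launcher):

# WORLD_SIZE=4 -> TP_SIZE=2, DP_SIZE=2
TP_SIZE, DP_SIZE = 2, 2
for world_rank in range(TP_SIZE * DP_SIZE):
    tp_rank = world_rank % TP_SIZE
    dp_rank = (world_rank - tp_rank) // TP_SIZE
    tp_group_ranks = list(range(dp_rank * TP_SIZE, (dp_rank + 1) * TP_SIZE))
    print(world_rank, tp_rank, dp_rank, tp_group_ranks)
# 0 0 0 [0, 1]
# 1 1 0 [0, 1]
# 2 0 1 [2, 3]
# 3 1 1 [2, 3]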
@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
         # FP8 enabled - true by the default
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
-        # modify_tensor_enabled - False by default
+        # modify_tensor_enabled - (False, None) by default
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
-        # inspect_tensor_enabled - False by default
+        # inspect_tensor_enabled - (False, None) by default
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.1.attn.qkv", tensor_name="activation", iteration=0
-        )
+        )[0]
-        # inspect_tensor_postquantize - False by default
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
     finally:
         debug_api.end_debug()
@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
         # caching
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
     finally:
         debug_api.end_debug()
@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
     finally:
         debug_api.end_debug()
@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
         # check modify_tensor_enabled
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
-        )
+        )[0]
         # check modify_tensor
@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
             gemm="wgrad",
             tensor_name="gradient",
             iteration=0,
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc4",
             gemm="fprop",
             tensor_name="activation",
             iteration=0,
-        )
+        )[0]
     finally:
         debug_api.end_debug()
@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
         # modify_tensor_enabled
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
         # modify_tensor
         debug_api.transformer_engine.modify_tensor(
@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.fc2", gemm="wgrad", iteration=0
-        )
+        )[0]
         # caching
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.fc2", gemm="wgrad", iteration=0
-        )
+        )[0]
     finally:
         debug_api.end_debug()
@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
         )
         tensor = torch.randn((100, 100, 5)).cuda()
-        tensor_fp8 = Float8Tensor(
-            data=tensor.to(torch.uint8).cuda(),
-            fp8_scale_inv=torch.full([1], 1.0).cuda(),
-            fp8_dtype=tex.DType.kFloat8E4M3,
-            shape=tensor.shape,
-            dtype=torch.float32,
-        )
+        quantizer = Float8Quantizer(
+            scale=torch.full([1], 1.0).cuda(),
+            amax=torch.full([1], 1.0).cuda(),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        tensor_fp8 = quantizer(tensor)
         def log():
             from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
             tensor_name="activation",
             iteration=200,
             tp_group=None,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
         )
         stats = log()
         assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
-        )
+        )[0]
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
-        )
+        )[0]
-        assert not debug_api.transformer_engine.inspect_tensor_enabled(
-            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
-        )
-        expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
-        expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
+        expected_underflows = (
+            ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
+        )
+        assert debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]
         # TE FP8 tensor stats --
-        assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
-        debug_api.transformer_engine.inspect_tensor_postquantize(
+        assert debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]
+        debug_api.transformer_engine.inspect_tensor(
             "decoder.1.mlp.fc1",
-            tensor=tensor_fp8,
             tensor_name="gradient",
             iteration=200,
-            rowwise=True,
             tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
         )
         stats = log()
         torch.testing.assert_close(
             stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
         )
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
-        )
+        assert not debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
+        )[0]
-        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
-            "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
+        assert not debug_api.transformer_engine.inspect_tensor_enabled(
+            "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
+        )[0]
         # Second config in same yaml
         tensor = torch.rand((100, 100, 5))
         debug_api.transformer_engine.inspect_tensor(
             "decoder.6.mlp.fc1",
-            tensor=tensor,
             tensor_name="activation",
             iteration=200,
             tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
         )
         stats = log()
         stats_names = [x[3] for x in stats.keys()]
@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
         debug_api.transformer_engine.inspect_tensor(
             "decoder.7.mlp.fc1",
-            tensor=tensor,
             tensor_name="weight",
             iteration=200,
             tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
         )
         stats = log()
         stats_names = [x[3] for x in stats.keys()]
@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
-        )
+        )[0]
         assert_empty()
     finally:
@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
         default_logging_enabled=False,
     )
-    def feed(tensor, tensor_fp8):
+    def feed(tensor, tensor_fp8, quantizer):
         debug_api.transformer_engine.inspect_tensor(
             "decoder.5.mlp.fc1",
             tensor=tensor,
             tensor_name="activation",
             iteration=1,
             tp_group=None,
-        )
-        debug_api.transformer_engine.inspect_tensor_postquantize(
-            "decoder.5.mlp.fc1",
-            tensor=tensor_fp8,
-            tensor_name="activation",
-            iteration=1,
-            rowwise=True,
-            tp_group=None,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=tensor_fp8,
+            columnwise_quantized_tensor=tensor_fp8,
         )
     def log_stats():
@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
         return STATS_BUFFERS.log_stats()
-    def fp8_tensor(t):
-        return Float8Tensor(
-            data=t.to(torch.uint8).cuda(),
-            fp8_scale_inv=torch.ones([1]).cuda(),
-            fp8_dtype=tex.DType.kFloat8E4M3,
-            shape=t.shape,
-            dtype=torch.float32,
-        )
+    quantizer = Float8Quantizer(
+        scale=torch.full([1], 1.0).cuda(),
+        amax=torch.full([1], 1.0).cuda(),
+        fp8_dtype=tex.DType.kFloat8E4M3,
+    )
+    def fp8_tensor(t):
+        return quantizer(t.cuda())
     shape = [1024, 1024]
     tensors = [torch.randn(shape) for _ in range(2)]
     tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
-    feed(tensors[0], tensors_fp8[0])
-    feed(tensors[1], tensors_fp8[1])
+    feed(tensors[0], tensors_fp8[0], quantizer)
+    feed(tensors[1], tensors_fp8[1], quantizer)
     stats1 = log_stats()
     tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
     fp8tensor2 = fp8_tensor(tensor2)
-    feed(tensor2, fp8tensor2)
+    feed(tensor2, fp8tensor2, quantizer)
     stats2 = log_stats()
     assert len(stats1.keys()) > 0
...
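The recurring pattern in this file: instead of hand-assembling a Float8Tensor from raw uint8 data, the tests now build a Float8Quantizer and call it on a high-precision tensor, which performs a real quantization and provides a matching dequantize(). A minimal standalone sketch using the same imports the test file already has:

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

t = torch.randn(1024, 1024).cuda()
quantizer = Float8Quantizer(
    scale=torch.full([1], 1.0).cuda(),  # identity scale
    amax=torch.full([1], 1.0).cuda(),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
t_fp8 = quantizer(t)                # genuine FP8 quantization of t
t_roundtrip = t_fp8.dequantize()    # compare against t to measure quantization error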
test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 1
          freq: 3
    LogFp8TensorStats:
      enabled: True
      tensors: activation
      stats: [underflows%]
      start_step: 1
      freq: 5
\ No newline at end of file
test:
  enabled: True
  layers:
    layer_name_regex_pattern: .*1
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          start_step: 0
          freq: 100000
\ No newline at end of file
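These two YAML files appear to be the log_config.yaml and perf_config.yaml referenced by the tests further down (the diff view collapsed the filenames, so this pairing is an inference). The scheduling rule they assume, judging from the assertions in test_log_every_3_or_5_layers, is that a stat fires on iterations i with i >= start_step and i % freq == 0:

# Assumed scheduling rule, inferred from the test assertions below;
# not an official description of the nvdlfw_inspect API.
def should_log(iteration, start_step, freq):
    return iteration >= start_step and iteration % freq == 0

assert [i for i in range(1, 10) if should_log(i, 1, 3)] == [3, 6, 9]
assert [i for i in range(1, 11) if should_log(i, 1, 5)] == [5, 10]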
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      stats: [
        {stats}
      ]
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""
recipes = [
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_block_scaling",
"mxfp8",
]
bare_stats = [
"underflows%",
"scale_inv_min",
"scale_inv_max",
"mse",
]
all_stats = []
for r in recipes:
for stat in bare_stats:
for columnwise_postfix in ["", "_columnwise"]:
if (
r in ["fp8_current_scaling", "fp8_block_scaling"]
and torch.cuda.get_device_capability()[0] < 9
):
# hopper is needed for current-scaling, block-scaling
continue
if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
# blackwell is needed for mxfp8
continue
if (
r in ["fp8_delayed_scaling", "fp8_current_scaling"]
and columnwise_postfix == "_columnwise"
):
# columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
continue
all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
"""
Helper context manager that
1. writes the YAML `config_str` to a temporary file,
2. starts a debug session, and
3. yields the directory that contains the statistics log.
The session is closed automatically, even on exceptions, so every test
stays concise and leak-free.
"""
with tempfile.NamedTemporaryFile(
mode="w", delete=False
) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
cfg_file.write(config_str)
cfg_file.flush()
debug_api.initialize(
config_file=cfg_file.name,
feature_dirs=feature_dirs,
log_dir=log_dir,
)
try:
yield log_dir
finally:
debug_api.end_debug()
def read_log(log_dir: str) -> str:
"""Return the content of the statistics log produced by `debug_session`."""
stat_path = os.path.join(
log_dir,
"nvdlfw_inspect_statistics_logs",
"nvdlfw_inspect_globalrank-0.log",
)
with open(stat_path, "r") as f:
return f.read()
def test_sanity(feature_dirs):
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
assert output, "Output is empty"
for stat in all_stats:
assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
pytest.skip(reason_for_no_mxfp8)
if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
pytest.skip(reason_for_no_fp8_block_scaling)
log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
recipe_state = RecipeState.create(
fp8_recipe,
mode="forward",
num_quantizers=3,
)
tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
for line in output.splitlines():
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
)
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
(abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
)
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    # If a layer does not invoke any feature in the current iteration,
    # it switches into non-debug mode.
    # This test checks that this works correctly:
    # non-quantized statistics should be logged every 3 iterations,
    # and quantized statistics every 5 iterations.
with tempfile.TemporaryDirectory() as temp_dir:
debug_api.initialize(
config_file=configs_dir + "/log_config.yaml",
feature_dirs=feature_dirs,
log_dir=temp_dir,
)
if layer == "linear":
model = te.Linear(128, 128, name="linear1")
elif layer == "transformer":
model = te.TransformerLayer(128, 128, 4, name="transformer1")
else:
raise ValueError(f"Invalid layer: {layer}")
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
with open(
os.path.join(
temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
),
"r",
) as f:
file_content = f.read()
for i in range(1, 20):
if i % 3 == 0 or i % 5 == 0:
assert f"iteration={i:06d}" in file_content
else:
assert f"iteration={i:06d}" not in file_content
debug_api.end_debug()
TEDebugState._reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
debug_api.end_debug()
TEDebugState._reset()
if debug_tools_initialized:
        # This config logs stats starting from iteration 0, every N iterations for a huge N >> NUM_ITERS.
        # So after one warm-up iteration, these layers should run in non-debug mode.
debug_api.initialize(
config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
)
try:
if layer == "linear":
model = torch.nn.Sequential(
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
).cuda()
NUM_ITERS = 18000
elif layer == "transformer":
model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda()
NUM_ITERS = 2000
x = torch.randn(1, 1, 1).cuda()
y = model(x)
y.sum().backward()
debug_api.step()
torch.cuda.synchronize()
time_start = time.time()
for i in range(NUM_ITERS):
y = model(x)
y.sum().backward()
if debug_tools_initialized:
debug_api.step()
torch.cuda.synchronize()
time_end = time.time()
finally:
if debug_tools_initialized:
debug_api.end_debug()
return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
    # Runs one layer many times on a very small tensor
    # - GPU time should be negligible, so total time is dominated by CPU time.
    # If a layer does not invoke any feature in the current iteration,
    # it switches into non-debug mode and should add no measurable CPU overhead
    # compared to a layer without debug tools initialized.
with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
print(f"with_debug_tools: {with_debug_tools} s")
print(f"without_debug_tools: {without_debug_tools} s")
assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin
@@ -519,6 +519,7 @@ def _train(opts):
     if opts.use_cuda_graphs:
         del test_graph
+    torch.cuda.synchronize()
     te.module.base.destroy_ub()
     dist_print("Destroying Userbuffers objects...", debug=True)
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig
model_configs = {
"small": ModelConfig(2, 10, 2, 16),
}
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
num_devices = torch.cuda.device_count()
assert num_devices > 1, "This test requires more than one GPU!"
tensor_device = num_devices - 1
dtype = torch.bfloat16
config = model_configs[model]
args = []
kwargs = {}
bwd_args = []
if module == "TransformerLayer":
model = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
params_dtype=dtype,
attn_input_format="thd",
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
(num_tokens, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=tensor_device,
requires_grad=True,
)
for _ in range(3)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
elif module == "Linear":
model = Linear(
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
if module == "DotProductAttention":
out.backward(*bwd_args)
else:
loss = out.sum()
loss.backward()
current_device_after = torch.cuda.current_device()
tensor_device_out = out.get_device()
tensor_device_grad = args[0].grad.get_device()
assert (
current_device_after == current_device_before
), "The current device should not have changed!"
assert (
tensor_device_out == tensor_device
), "The output tensor should be the same as the input tensors!"
assert (
tensor_device_grad == tensor_device
), "The gradient tensor should be the same as the input tensors!"
@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
...