Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
@@ -23,7 +23,7 @@ ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
(4096, 32, 1280, 2),
- (4096, 256, 4096, 6),
+ (4096, 64, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:2],
@@ -44,7 +44,7 @@ ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
(4096, 32, 1280, 2, 128),
- (4096, 256, 4096, 6, 16),
+ (4096, 64, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
...
@@ -74,6 +74,14 @@ if not IS_HIP_EXTENSION:
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)
# Get determinism
_deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
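For reference, a minimal standalone sketch of the same check (the helper name is hypothetical; the environment variable and PyTorch API are the ones used above):

```python
import os
import torch

def is_deterministic() -> bool:
    # Tests run deterministically when NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
    # or when torch's deterministic-algorithms mode is enabled.
    allow_nondet = bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    return (not allow_nondet) or torch.are_deterministic_algorithms_enabled()
```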
# Reset RNG seed and states
seed = 1234
reset_rng_states()
@@ -147,6 +155,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
@@ -162,8 +171,10 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = get_available_attention_backends(
@@ -172,6 +183,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -421,6 +433,15 @@ def test_dpa_softmax(dtype, model_configs, model):
)
@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
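For readers unfamiliar with the "thd" (packed, variable-length) layout exercised here, a small illustrative sketch of how such batches are typically described (values hypothetical, not taken from the test):

```python
import torch

# Three sequences of lengths 3, 5, and 2 packed along the token dimension ("t" in thd).
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
# cu_seqlens == tensor([0, 3, 8, 10]); q/k/v then have shape [10, num_heads, head_dim]
# and attention kernels recover sequence boundaries from cu_seqlens.
```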
model_configs_mla = {
#TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
@@ -685,9 +706,10 @@ model_configs_swa = {
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
- def test_dpa_sliding_window(dtype, model_configs, model):
+ @pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
+ def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with sliding window attention"""
- test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
+ test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
model_configs_alibi_slopes = {
@@ -889,11 +911,14 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
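The same backend-forcing pattern, factored into a hypothetical helper for readability (the NVTE_* variable names are the ones set above; this helper is not part of the test suite):

```python
import os
from contextlib import contextmanager

_BACKEND_FLAGS = {
    "FlashAttention": "NVTE_FLASH_ATTN",
    "FusedAttention": "NVTE_FUSED_ATTN",
    "UnfusedDotProductAttention": "NVTE_UNFUSED_ATTN",
}

@contextmanager
def force_backend(name):
    # Enable exactly one backend via its flag, disable the others, and restore afterwards.
    saved = {var: os.environ.get(var) for var in _BACKEND_FLAGS.values()}
    for backend, var in _BACKEND_FLAGS.items():
        os.environ[var] = "1" if backend == name else "0"
    try:
        yield
    finally:
        for var, val in saved.items():
            if val is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = val
```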
# Create seqlens
@@ -1295,6 +1320,7 @@ def test_transformer_layer(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
@@ -1308,6 +1334,7 @@ def test_transformer_layer(
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -1435,10 +1462,13 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
@@ -1632,6 +1662,7 @@ def test_dpa_fp8_extra_state(model, dtype):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
@@ -1822,6 +1853,7 @@ def test_mha_fp8_vs_f16(
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
@@ -1833,6 +1865,7 @@ def test_mha_fp8_vs_f16(
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
@@ -1841,6 +1874,7 @@ def test_mha_fp8_vs_f16(
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1850,6 +1884,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1859,6 +1894,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
@@ -2071,6 +2107,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
@@ -2081,6 +2118,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
@@ -2091,6 +2129,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2100,6 +2139,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if unfused_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2108,6 +2148,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2116,6 +2157,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
@@ -2370,13 +2412,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
- unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
+ unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
+     dtype, config, "UnfusedDotProductAttention"
+ )
atol = 5e-1
rtol = 5e-1
@@ -2409,10 +2454,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
@@ -2463,10 +2511,13 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
@@ -2754,7 +2805,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
cu_seqlens,
max_s,
) -> torch.Tensor:
- with self.prepare_forward(inp, num_gemms=3) as inp:
+ with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
...
@@ -148,7 +148,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
- "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
+ "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
@@ -164,7 +164,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
- 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+ 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
@@ -188,7 +188,16 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
- configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+ configs = [
+ "cp_1_0",
+ "cp_1_1",
+ "cp_1_4",
+ "cp_2_0",
+ "cp_2_2",
+ "cp_2_4",
+ "cp_3_2",
+ "cp_4_2",
+ ]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
@@ -284,9 +293,14 @@ def test_cp_with_fused_attention(
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
- if config.softmax_type != "vanilla" and qkv_format == "thd":
+ if (
+ get_cudnn_version() < (9, 18, 0)
+ and config.softmax_type != "vanilla"
+ and qkv_format == "thd"
+ ):
pytest.skip(
- "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+ "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
+ " non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
...
@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
LOG_QUANTIZED_CONFIG_BASE = """
log:
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState._reset()
# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogNvfp4TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
def test_nvfp4_numeric(feature_dirs):
"""Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantization import RecipeState
recipe_state = RecipeState.create(
recipe.NVFP4BlockScaling(),
mode="forward",
num_quantizers=3,
)
# Create test tensor with known distribution
torch.manual_seed(42)
tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Add some small values that should underflow to zero in FP4
tensor[0, :16] = 0.0001
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="test_layer",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
# Validate both stats are present
assert "nvfp4_underflows%" in output, "underflows% stat missing"
assert "nvfp4_mse" in output, "mse stat missing"
# Extract values and validate numerics
underflows_value = None
mse_value = None
for line in output.splitlines():
if "nvfp4_underflows%" in line and "value=" in line:
underflows_value = float(line.split("value=")[1].split()[0])
if "nvfp4_mse" in line and "value=" in line:
mse_value = float(line.split("value=")[1].split()[0])
# Compute expected underflows: non-zero elements that became zero after quantization
orig_nonzero_mask = tensor != 0
dequant_zero_mask = dequantized_tensor == 0
expected_underflows = (
(orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
)
# Allow some tolerance
assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
# Compute expected MSE
expected_mse = torch.nn.functional.mse_loss(
dequantized_tensor.float(), tensor.float(), reduction="mean"
)
assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
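For reference, the two statistics validated above reduce to a few lines of plain PyTorch given any quantize/dequantize round trip (a sketch; the helper name is hypothetical):

```python
import torch

def underflow_pct_and_mse(orig: torch.Tensor, dequant: torch.Tensor):
    # Percentage of originally non-zero elements that became exactly zero after quantization.
    underflow_pct = ((orig != 0) & (dequant == 0)).sum().float() / orig.numel() * 100
    # Mean squared error of the dequantized tensor relative to the original.
    mse = torch.nn.functional.mse_loss(dequant.float(), orig.float())
    return underflow_pct.item(), mse.item()
```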
def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
"""Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
# Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
with debug_session(log_fp8_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Should work - recipe-prefixed stats compute MXFP8 separately for comparison
for _ in range(2):
with te.autocast(recipe=recipe.NVFP4BlockScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
# Should have logged MXFP8 MSE stat (what-if scenario)
assert "mxfp8_mse" in output
def test_log_grouped_gemm(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
...
@@ -30,10 +30,17 @@ configs = {
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
""",
"log_fp8": """log_fp8:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
- stats: [underflows, overflows]
+ stats: [underflows%]
start_step : 0
end_step: 1
""",
@@ -46,22 +53,26 @@ fake_quant_config:
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
tensors: [activation, weight, gradient]
quant_format: FP8E5M2
""",
}
# Configs that require FP8 to be enabled
fp8_required_configs = {"log_fp8"}
def _get_model(model_key):
if model_key == "linear":
- return te.Linear(D, D)
+ return te.Linear(D, D, name="layer")
if model_key == "layernorm_linear":
- return te.LayerNormLinear(D, D)
+ return te.LayerNormLinear(D, D, name="layer")
if model_key == "layernorm_mlp":
- return te.LayerNormMLP(D, D, D)
+ return te.LayerNormMLP(D, D, D, name="layer")
if model_key == "mha_attention":
- return te.MultiheadAttention(D, H)
+ return te.MultiheadAttention(D, H, name="layer")
if model_key == "transformer_layer":
- return te.TransformerLayer(D, D, H)
+ return te.TransformerLayer(D, D, H, name="layer")
def _run_forward_backward(model, fp8):
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if not fp8 and config_key in fp8_required_configs:
pytest.skip(f"Config '{config_key}' requires FP8")
_run_test(model_key, fp8, configs[config_key], feature_dirs)
@@ -101,7 +101,7 @@ class TestLoadCheckpoint:
# Path to save checkpoint
if checkpoint_dir is None:
checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
- checkpoint_dir.mkdir(exist_ok=True)
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = checkpoint_dir / f"{name}.pt"
# Create module and save checkpoint
...
@@ -5,8 +5,10 @@
from __future__ import annotations
from collections.abc import Iterable
import functools
import io
import math
import random
from typing import Optional
import pytest
@@ -37,7 +39,14 @@ from transformer_engine.pytorch import (
import transformer_engine_torch as tex
# Import utility functions
- from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
+ from utils import (
+ assert_close,
+ assert_close_grads,
+ dtype_tols,
+ make_recipe,
+ quantization_tols,
+ reset_rng_states,
+ )
if IS_HIP_EXTENSION:
import os
@@ -116,6 +125,9 @@ def maybe_skip_quantization(
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
*,
min: float = 0.0,
max: float = 1.0,
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
@@ -136,7 +148,8 @@ def make_reference_and_test_tensors(
"""
# Random reference tensor
- ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+ ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
+ ref.uniform_(min, max)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
@@ -1569,7 +1582,19 @@ class TestBasicOps:
@pytest.mark.parametrize(
"activation",
- ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+ (
+ "gelu",
+ "geglu",
+ "qgelu",
+ "qgeglu",
+ "relu",
+ "reglu",
+ "glu",
+ "srelu",
+ "sreglu",
+ "silu",
+ "swiglu",
+ ),
)
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@@ -1589,7 +1614,7 @@ class TestBasicOps:
# Tensor dimensions
in_shape = list(out_shape)
- if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
+ if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
in_shape[-1] *= 2
# Skip invalid configurations
@@ -1629,6 +1654,13 @@ class TestBasicOps:
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
elif activation == "glu":
x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
x = x.flip(-2) # PyTorch GLU swaps gate and linear unit
x = x.reshape(in_shape)
y_ref = torch.nn.functional.glu(x)
elif activation == "srelu":
y_ref = torch.nn.functional.relu(x_ref) ** 2
elif activation == "sreglu":
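A small standalone check of the flip used in the "glu" reference above: it only swaps the two halves of the last dimension, so torch.nn.functional.glu (which applies the sigmoid to the second half of its input) ends up gating with the first half instead (shapes hypothetical):

```python
import torch

x = torch.randn(4, 8)
x_swapped = x.reshape(4, 2, 4).flip(-2).reshape(4, 8)  # swap the two halves of the last dim
x1, x2 = x.chunk(2, dim=-1)
# After the swap, F.glu computes sigmoid(x1) * x2 rather than x1 * sigmoid(x2).
assert torch.allclose(torch.nn.functional.glu(x_swapped), torch.sigmoid(x1) * x2)
```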
@@ -1648,6 +1680,7 @@ class TestBasicOps:
make_op = dict(
gelu=te_ops.GELU,
geglu=te_ops.GEGLU,
glu=te_ops.GLU,
qgelu=te_ops.QGELU,
qgeglu=te_ops.QGEGLU,
relu=te_ops.ReLU,
@@ -1692,6 +1725,7 @@ class TestBasicOps:
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_interleave_size: Optional[int] = None,
):
# Tensor dimensions
@@ -1718,7 +1752,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
- x1, x2 = x_ref.chunk(2, dim=-1)
+ x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)
@@ -1726,7 +1770,7 @@ class TestBasicOps:
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
- te_ops.SwiGLU(),
+ te_ops.SwiGLU(glu_interleave_size=glu_interleave_size),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1739,10 +1783,19 @@ class TestBasicOps:
tols = quantization_tols(quantization)
# Check results
- y_test = y_test.to(dtype=torch.float64, device="cpu")
- dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
- torch.testing.assert_close(y_test, y_ref, **tols)
- torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+ assert_close(y_test, y_ref, **tols)
+ assert_close_grads(x_test, x_ref, **tols)
+ def test_interleaved_swiglu(self):
"""SwiGLU with block interleaved input format"""
self.test_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
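A standalone check (values hypothetical) of what the block-interleaved reshape in the reference implementation does: with blocks of size glu_interleave_size, an input laid out as alternating gate/linear blocks is regrouped into the plain [gates | linear] split that the reference SwiGLU then chunks (assuming, as in the reference above, that the first chunk plays the gate role):

```python
import torch

s = 2  # stands in for glu_interleave_size
x = torch.arange(8).reshape(1, 8)  # last dim laid out as [g0, l0, g1, l1] in blocks of size s
x = x.reshape(1, 8 // (2 * s), 2, s).transpose(-3, -2).reshape(1, 8)
# Gate blocks now come first, followed by the linear blocks: [g0, g1, l0, l1]
assert x.tolist() == [[0, 1, 4, 5, 2, 3, 6, 7]]
```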
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@@ -1752,6 +1805,7 @@ class TestBasicOps:
self,
*,
out_shape: Iterable[int] = (32, 32),
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
@@ -1760,7 +1814,7 @@ class TestBasicOps:
limit: float = 0.75,
alpha: float = 1.702,
):
- # Test SwiGLU variant used in GPT OSS.
+ """SwiGLU variant used in GPT-OSS"""
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
@@ -1785,7 +1839,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
- x_glu, x_linear = x_ref.chunk(2, dim=-1)
+ x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
@@ -1797,7 +1861,11 @@ class TestBasicOps:
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
- te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+ te_ops.ClampedSwiGLU(
+     limit=limit,
+     alpha=alpha,
+     glu_interleave_size=glu_interleave_size,
+ ),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1813,10 +1881,19 @@ class TestBasicOps:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
- y_test = y_test.to(dtype=torch.float64, device="cpu")
- dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
- torch.testing.assert_close(y_test, y_ref, **tols)
- torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+ assert_close(y_test, y_ref, **tols)
+ assert_close_grads(x_test, x_ref, **tols)
+ def test_interleaved_clamped_swiglu(self):
"""GPT-OSS SwiGLU with block interleaved input format"""
self.test_clamped_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@@ -1936,6 +2013,231 @@ class TestBasicOps:
abs(z_score) < 2.5758
), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
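For context, a sketch of one standard way such a z-score is formed, via the normal approximation to a binomial proportion (2.5758 is the two-sided 99% quantile); the exact computation in the test lies outside the visible hunk, so treat this as an assumption:

```python
import math

def zero_fraction_z_score(prob, prob_observed, num_samples):
    # (observed - expected) / standard error of the expected proportion
    return (prob_observed - prob) / math.sqrt(prob * (1 - prob) / num_samples)

# e.g. expecting 10% zeros over 10_000 samples and observing 10.5%:
assert abs(zero_fraction_z_score(0.10, 0.105, 10_000)) < 2.5758
```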
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_grouped_linear(
self,
*,
group_size: int = 4,
bias: bool,
weight_shape: tuple[int, int] = (128, 128),
split_alignment: int = 128,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""Grouped GEMM"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = (split_sizes.sum().item(), in_features)
out_shape = (in_shape[0], out_features)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
ws_ref, ws_test = [], []
bs_ref, bs_test = [], []
for _ in range(group_size):
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
ws_ref.append(w_ref)
ws_test.append(w_test)
bs_ref.append(b_ref)
bs_test.append(b_test)
# Plain PyTorch implementation
xs_ref = torch.split(x_ref, split_sizes.tolist())
ys_ref = []
for x, w, b in zip(xs_ref, ws_ref, bs_ref):
ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
y_ref = torch.cat(ys_ref)
if input_requires_grad or weight_requires_grad:
y_ref.backward(dy_ref)
# Construct fusible operation
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.GroupedLinear(
group_size,
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
for group_idx in range(group_size):
getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
if bias:
getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
del ws_test, bs_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
# Forward and backward pass with op
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test, split_sizes)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
else:
assert x_test.grad is None
for group_idx in range(group_size):
w_test = getattr(op, f"weight{group_idx}")
if weight_requires_grad:
dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
else:
assert w_test.grad is None
if bias:
b_test = getattr(op, f"bias{group_idx}")
if weight_requires_grad:
db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
else:
assert b_test.grad is None
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_swiglu(
self,
*,
in_shape: Iterable[int],
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SwiGLU with post-scale"""
# Tensor dims
out_shape = list(in_shape)
out_shape[-1] //= 2
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
-1,
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y = torch.nn.functional.silu(x1) * x2
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)
def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
in_shape=(32, 192),
glu_interleave_size=32,
input_requires_grad=True,
scales_requires_grad=True,
)
class TestFusedOps:
"""Tests for fused operations"""
@@ -2342,13 +2644,13 @@ class TestFusedOps:
backward_ops = model._module_groups[0]._backward_ops
if with_quantization:
assert len(backward_ops) == 2
- assert isinstance(backward_ops[0][0], BackwardActivationBias)
- assert isinstance(backward_ops[1][0], te_ops.Quantize)
+ assert isinstance(backward_ops[0][0], te_ops.Quantize)
+ assert isinstance(backward_ops[1][0], BackwardActivationBias)
else:
assert len(backward_ops) == 3
- assert isinstance(backward_ops[0][0], act_type)
+ assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Bias)
- assert isinstance(backward_ops[2][0], te_ops.Quantize)
+ assert isinstance(backward_ops[2][0], act_type)
# Expected numerical error
tols = dtype_tols(dtype)
@@ -2944,3 +3246,499 @@ class TestSequentialModules:
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
def test_grouped_mlp(
self,
*,
group_size: int = 4,
bias: bool,
hidden_size: int = 256,
dtype: torch.dtype,
quantization: Optional[str],
device: torch.device = "cuda",
split_alignment: int = 256,
glu_interleave_size: Optional[int],
) -> None:
"""GroupedLinear + ScaledSwiGLU + GroupedLinear"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input shape
in_shape = (split_sizes.sum().item(), hidden_size)
out_shape = in_shape
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
probs_ref, probs_test = make_reference_and_test_tensors(
(in_shape[0],),
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref, fc1_ws_test = [], []
fc1_bs_ref, fc1_bs_test = [], []
fc2_ws_ref, fc2_ws_test = [], []
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc2_w_ref, fc2_w_test = make_reference_and_test_tensors(
(hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc1_b_ref, fc1_b_test = None, None
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc2_b_ref, fc2_b_test = make_reference_and_test_tensors(
(hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref.append(fc1_w_ref)
fc1_bs_ref.append(fc1_b_ref)
fc1_ws_test.append(fc1_w_test)
fc1_bs_test.append(fc1_b_test)
fc2_ws_ref.append(fc2_w_ref)
fc2_bs_ref.append(fc2_b_ref)
fc2_ws_test.append(fc2_w_test)
fc2_bs_test.append(fc2_b_test)
# Reference implementation
xs = torch.split(x_ref, split_sizes.tolist())
probs = torch.split(probs_ref, split_sizes.tolist())
ys = []
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
ys.append(x)
y_ref = torch.cat(ys)
y_ref.backward(dy_ref)
# Construct operations
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
fc2 = te_ops.GroupedLinear(
group_size,
hidden_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
module = te_ops.Sequential(
fc1,
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
fc2,
)
# Copy weights
with torch.no_grad():
for group_idx in range(group_size):
getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx])
getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx])
if bias:
getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx])
getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx])
del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test
# Fuse ops and perform forward and backward pass
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = module(x_test, split_sizes, probs_test, split_sizes)
y_test.backward(dy_test)
# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
if quantization == "nvfp4":
tols = {"rtol": 0.25, "atol": 0.5}
# Check values
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(probs_test, probs_ref, **tols)
for group_idx in range(group_size):
assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)
class TestCustomOps:
"""Test with ops that are defined externally"""
def test_custom_basic_op(
self,
*,
shape: Iterable[int] = (7, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
) -> None:
"""Custom basic op"""
class CustomScaleOp(te.ops.BasicOperation):
"""Custom op that applies a learnable scale"""
def __init__(self) -> None:
super().__init__()
self.scale: torch.nn.Parameter
scale = torch.ones((), dtype=dtype, device=device)
scale = torch.nn.Parameter(scale)
self.register_parameter("scale", scale)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.save_for_backward(self.scale, input_)
return self.scale * input_
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> torch.Tensor:
(
scale,
input_,
) = ctx.saved_tensors
grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
grad_scale = grad_scale.reshape(())
grad_input = scale * grad_output
return grad_input, (grad_scale,)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = w_ref * x_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = CustomScaleOp()
forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
with torch.no_grad():
op.scale.copy_(w_test)
del w_test
y_test = forward(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_forward_fused_op(
self,
*,
shape: Iterable[int] = (7, 11),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in forward pass"""
class CustomForwardLinearSiLU(te.ops.FusedOperation):
"""Custom fused op for GEMM + SiLU"""
_enabled = True
def __init__(self, *, linear, silu) -> None:
super().__init__((linear, silu))
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
**unused,
) -> torch.Tensor:
weight = self.basic_ops[0].weight
dtype = weight.dtype
device = weight.device
# Perform compute on CPU to demonstrate that a fused implementation may differ arbitrarily from the basic ops
x = input_.cpu()
w = weight.cpu()
y = torch.matmul(x, w.T)
z = torch.nn.functional.silu(y)
out = z.to(device=device)
# Save state for linear backward
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx.save_for_backward(input_, weight)
linear_op_ctx.with_quantized_compute = False
linear_op_ctx.input_quantizer = None
linear_op_ctx.weight_quantizer = None
linear_op_ctx.grad_output_quantizer = None
linear_op_ctx.grad_input_quantizer = None
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = True
linear_op_ctx.weight_requires_grad = True
# Save state for SiLU backward
silu_op_ctx = basic_op_ctxs[1]
silu_op_ctx.save_for_backward(y.to(device=device))
silu_op_ctx.dtype = dtype
silu_op_ctx.prev_op_grad_output_quantizer = None
return out, [(), ()]
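# Note (added for clarity): fuser_forward returns the fused output plus a list with one entry
# of extra outputs per basic op; both are empty tuples here since neither op produces any.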
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomForwardLinearSiLU._enabled:
CustomForwardLinearSiLU._enabled = False
op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref = torch.nn.functional.silu(y_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
model = te.ops.Sequential(
te.ops.Linear(shape[-1], shape[-1], bias=False),
te.ops.SiLU(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_backward_fused_op(
self,
*,
shape: Iterable[int] = (13, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in backward pass"""
class CustomBackwardLinearScale(te.ops.FusedOperation):
"""Custom fused op for backward linear + scale"""
_enabled: bool = True
def __init__(self, *, scale, linear) -> None:
super().__init__((scale, linear))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
**unused,
) -> torch.Tensor:
# Load state from linear forward
linear_op_ctx = basic_op_ctxs[1]
x, w = linear_op_ctx.saved_tensors
dtype = linear_op_ctx.dtype
device = w.device
# Perform compute in FP64 and apply scale before dgrad
# GEMM instead of after
scale = self.basic_ops[0].scale
dy = grad_output.double()
x = x.double()
w = w.double()
dx = torch.matmul(dy, scale * w)
dw = torch.matmul(dy.T, x)
dx = dx.to(dtype=dtype)
dw = dw.to(dtype=dtype)
return dx, [(), (dw,)], [(), ()]
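# Note (added for clarity): fuser_backward returns the gradient w.r.t. the fused op's input,
# a per-basic-op list of parameter gradients (none for the scale op, (dw,) for the linear
# weight), and a per-basic-op list of gradients for extra inputs (empty here).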
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomBackwardLinearScale._enabled:
CustomBackwardLinearScale._enabled = False
op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
scale = 1.234
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
model = te.ops.Sequential(
te.ops.ConstantScale(scale),
te.ops.Linear(shape[-1], shape[-1], bias=False),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for GroupedTensor class"""
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex
# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
_quantization_params = [
pytest.param(
"fp8_delayed_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_current_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_blockwise",
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
pytest.param(
"mxfp8",
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
"nvfp4",
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
),
]
def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
"""Create quantizers for given quantization scheme"""
if quantization == "fp8_delayed_scaling":
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device="cuda"),
amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
fp8_dtype=tex.DType.kFloat8E4M3,
)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device="cuda",
)
quantizer.set_usage(rowwise=True, columnwise=False)
elif quantization == "fp8_blockwise":
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=False,
force_pow_2_scales=True,
amax_epsilon=0.0,
block_scaling_dim=1,
)
elif quantization == "mxfp8":
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
elif quantization == "nvfp4":
quantizer = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)
else:
raise ValueError(f"Unknown quantization scheme: {quantization}")
quantizer.internal = False
return quantizer
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
return qtensor._data
if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
return qtensor._rowwise_data
raise ValueError(f"Unknown quantization scheme: {quantization}")
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
if quantization == "nvfp4":
return numel // 2
return numel
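# Worked example (illustrative numbers): a (512, 512) tensor holds 262144 elements. With FP8
# data (one byte per element) the rowwise payload spans 262144 bytes, so tensor i starts at an
# offset of i * 262144 bytes; with NVFP4 two elements are packed per byte, so the stride
# halves to 131072 bytes.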
class TestGroupedTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_basic_construction_all_same_shape(self) -> None:
"""Test GroupedTensor construction with all tensors having same shape"""
num_tensors = 4
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.all_same_shape()
assert grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
assert grouped_tensor.get_common_first_dim() == 256
assert grouped_tensor.get_common_last_dim() == 512
assert grouped_tensor.has_data()
def test_basic_construction_varying_first_dim(self) -> None:
"""Test GroupedTensor construction with varying first dimension"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert not grouped_tensor.all_same_shape()
assert not grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.get_common_last_dim() == shape[0][1]
assert grouped_tensor.logical_shape == (
sum(v for v, _ in shape),
shape[0][1],
) # sum of first dims
def test_split_into_quantized_tensors_no_quantization(self) -> None:
"""Test split_into_quantized_tensors for unquantized tensors"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor has correct shape and shares storage
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
assert isinstance(tensor, torch.Tensor)
assert not hasattr(tensor, "_data") # Not a quantized tensor
# Verify data pointer is within the original grouped tensor storage
# The tensor should be a view of the original data
assert tensor.data_ptr() >= original_data_ptr
# Calculate expected offset
expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor shares storage with the grouped tensor
for i, tensor in enumerate(tensors):
rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
assert rowwise_data is not None
assert rowwise_data.data_ptr() >= original_data_ptr
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_split_varying_shapes(self) -> None:
"""Test split_into_quantized_tensors with varying shapes"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
original_data_ptr = grouped_tensor.data.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify shapes and storage
cumulative_offset = 0
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
expected_offset = cumulative_offset * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
cumulative_offset += shape[i][0] * shape[i][1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
# Verify returned tensors point to the same storage
for i, qtensor in enumerate(quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
# Verify each tensor points to correct location
cumulative_numel = 0
for qtensor, tensor_shape in zip(quantized_tensors, shape):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
quantizer=quantizers,
device="cuda",
)
# Verify the grouped tensor was created correctly
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.has_data()
# Verify quantized_tensors were created and point to same storage
assert grouped_tensor.quantized_tensors is not None
assert len(grouped_tensor.quantized_tensors) == num_tensors
original_data_ptr = grouped_tensor.data.data_ptr()
for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.has_data()
assert grouped_tensor.num_tensors == num_tensors
grouped_tensor.clear()
assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.logical_shape == (0, 0)
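# Minimal usage sketch (illustrative, not a test): build a GroupedTensor from a list of
# tensors, quantize them into one contiguous buffer, and split it back into storage-sharing
# views. Assumes a CUDA device and reuses the make_quantizer helper defined above.
#
#   inputs = [torch.randn(256, 512, device="cuda") for _ in range(3)]
#   quantizer = make_quantizer("mxfp8", num_tensors=3, shape=[(256, 512)] * 3)
#   grouped = GroupedTensor.create_and_quantize(tensors=inputs, quantizer=quantizer, device="cuda")
#   views = grouped.split_into_quantized_tensors()  # each view aliases grouped.data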
...@@ -94,6 +94,7 @@ all_boolean = [True, False] ...@@ -94,6 +94,7 @@ all_boolean = [True, False]
all_activations = [ all_activations = [
"gelu", "gelu",
"geglu", "geglu",
"glu",
"qgelu", "qgelu",
"qgeglu", "qgeglu",
"relu", "relu",
...@@ -484,6 +485,7 @@ class TorchGroupedLinearWithPadding(nn.Module): ...@@ -484,6 +485,7 @@ class TorchGroupedLinearWithPadding(nn.Module):
_supported_act = { _supported_act = {
"gelu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"),
"geglu": nn.GELU(approximate="tanh"), "geglu": nn.GELU(approximate="tanh"),
"glu": nn.Sigmoid(),
"qgelu": TorchQuickGELU(), "qgelu": TorchQuickGELU(),
"qgeglu": TorchQuickGELU(), "qgeglu": TorchQuickGELU(),
"relu": nn.ReLU(), "relu": nn.ReLU(),
......
...@@ -745,6 +745,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): ...@@ -745,6 +745,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
_test_export_layernorm_mlp(activation=activation) _test_export_layernorm_mlp(activation=activation)
# Quantization recipes with fp8_dpa=True for attention emulation export test
dpa_quantization_recipes = [None] # None = no quantization
if fp8_available:
dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", "precision, use_mask, attn_mask_type",
[ [
...@@ -762,6 +770,7 @@ def test_export_core_attention( ...@@ -762,6 +770,7 @@ def test_export_core_attention(
precision: torch.dtype, precision: torch.dtype,
use_mask: bool, use_mask: bool,
attn_mask_type: str, attn_mask_type: str,
fp8_recipe: recipe.Recipe,
): ):
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip") pytest.skip("ONNX is not currently required in hip")
...@@ -783,22 +792,25 @@ def test_export_core_attention( ...@@ -783,22 +792,25 @@ def test_export_core_attention(
mask_str = get_attn_mask_str(use_mask, attn_mask_type) mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision) high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
is_fp8 = fp8_recipe is not None
model = te.attention.DotProductAttention( model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
kv_channels=kv_channels, kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format, qkv_format=qkv_format,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
).to(device="cuda") ).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16,): if precision in (torch.bfloat16,):
return return
atol = 5e-1 if is_fp8 else 1e-2
validate_result( validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
) )
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from typing import Optional from typing import Optional, List
import torch import torch
import pytest import pytest
...@@ -114,6 +114,7 @@ batch_sizes_with_zero = [0, 1, 2] ...@@ -114,6 +114,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = [ all_activations = [
"gelu", "gelu",
"geglu", "geglu",
"glu",
"qgelu", "qgelu",
"qgeglu", "qgeglu",
"relu", "relu",
...@@ -138,6 +139,117 @@ def reset_global_fp8_state(): ...@@ -138,6 +139,117 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset() FP8GlobalStateManager.reset()
def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
"""
Verify that tensors are stored in contiguous memory.
Args:
tensors: List or iterable of tensors to check
num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
tensor_name: Name to use in error messages
"""
tensor_list = list(tensors)
if len(tensor_list) < 2:
return # Nothing to check
for i in range(1, len(tensor_list)):
prev_tensor = tensor_list[i - 1]
curr_tensor = tensor_list[i]
# Calculate expected offset based on previous tensor size
prev_numel = prev_tensor.numel()
expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
# Verify current tensor's data pointer is correctly offset
expected_ptr = prev_tensor.data_ptr() + expected_offset
actual_ptr = curr_tensor.data_ptr()
assert (
actual_ptr == expected_ptr
), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"
def check_grouped_tensor_pointers(
weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
"""
Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
TODO(ksivaman): This check could be made much more efficient; keeping the brute-force approach for now.
"""
num_elems_in_a_data_byte = 2 if fp8_recipe is not None and fp8_recipe.nvfp4() else 1
# Check data.
if hasattr(weights[0], "_data") and weights[0]._data is not None:
data_tensors = [w._data for w in weights]
check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")
# Check transpose.
if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
transpose_tensors = [w._transpose for w in weights]
check_grouped_tensor_pointers_helper(
transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
)
# Check scale_inv.
if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
scale_inv_tensors = [w._scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
)
# Check rowwise scale_inv.
if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
)
# Check columnwise scale_inv.
if (
hasattr(weights[0], "_columnwise_scale_inv")
and weights[0]._columnwise_scale_inv is not None
):
columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_scale_inv_tensors,
num_elems_in_byte=1,
tensor_name="columnwise scale_inv",
)
# Check rowwise amax.
if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
rowwise_amax_tensors = [w._rowwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
)
# Check columnwise amax.
if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
columnwise_amax_tensors = [w._columnwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
)
# Check rowwise data.
if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
rowwise_data_tensors = [w._rowwise_data for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="rowwise data",
)
# Check columnwise data.
if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
columnwise_data_tensors = [w._columnwise_data for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="columnwise data",
)
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
...@@ -486,10 +598,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -486,10 +598,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4]) @pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear( def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split dtype,
bs,
model,
fp8_recipe,
fp8_model_params,
use_bias,
single_param,
num_gemms,
empty_split,
): ):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.") pytest.skip("FP8 model parameters are not supported in debug mode.")
...@@ -499,6 +620,9 @@ def test_sanity_grouped_linear( ...@@ -499,6 +620,9 @@ def test_sanity_grouped_linear(
bs = bs * 16 bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if single_param:
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -508,9 +632,19 @@ def test_sanity_grouped_linear( ...@@ -508,9 +632,19 @@ def test_sanity_grouped_linear(
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear( te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype num_gemms,
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype,
).cuda() ).cuda()
# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
if single_param:
check_grouped_tensor_pointers(weights, fp8_recipe)
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda() ).cuda()
...@@ -528,6 +662,9 @@ def test_sanity_grouped_linear( ...@@ -528,6 +662,9 @@ def test_sanity_grouped_linear(
loss.backward() loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size) assert out.shape == (num_tokens, ffn_hidden_size)
if single_param:
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
...@@ -1005,7 +1142,13 @@ def test_replace_raw_data_for_float8tensor(): ...@@ -1005,7 +1142,13 @@ def test_replace_raw_data_for_float8tensor():
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] attrs_to_check = [
"_quantizer",
"_fp8_dtype",
"_scale_inv",
"_transpose",
"_transpose_invalid",
]
attrs = {} attrs = {}
for attr in attrs_to_check: for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr) attrs[attr] = getattr(fp8_tensor, attr)
......
...@@ -15,7 +15,7 @@ import torch ...@@ -15,7 +15,7 @@ import torch
import transformer_engine import transformer_engine
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import InferenceParams from transformer_engine.pytorch import InferenceParams, QuantizedTensor
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend, get_attention_backend,
...@@ -353,11 +353,56 @@ def get_available_attention_backends( ...@@ -353,11 +353,56 @@ def get_available_attention_backends(
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False: if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging() AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3): for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test() available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]: if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend) fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends return available_backends, flash_attention_backend, fused_attn_backends
@torch.no_grad
def assert_close(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
*,
check_device: bool = False,
check_dtype: bool = False,
check_layout: bool = False,
**kwargs,
) -> None:
"""Assert that two tensors are close.
This function is a wrapper around torch.testing.assert_close. It
changes the defaults for device and dtype checks (useful when the
reference implementation is computed in high precision on CPU) and
it can handle quantized tensors.
"""
if isinstance(actual, QuantizedTensor):
actual = actual.dequantize()
if isinstance(expected, QuantizedTensor):
expected = expected.dequantize()
torch.testing.assert_close(
actual,
expected,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
**kwargs,
)
def assert_close_grads(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
**kwargs,
) -> None:
"""Assert that two tensors have close gradients."""
if actual is None and expected is None:
return
assert actual is not None
assert expected is not None
assert_close(actual.grad, expected.grad, **kwargs)
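# Usage sketch (illustrative, hypothetical tensors): both helpers skip device/dtype/layout
# checks by default, which is convenient when the reference lives on CPU in higher precision.
#
#   y_ref = y_test.detach().double().cpu()
#   assert_close(y_test, y_ref, rtol=1e-2, atol=1e-2)
#   assert_close_grads(x_test, x_ref, rtol=1e-2, atol=1e-2)  # compares .grad of both tensors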
...@@ -202,6 +202,7 @@ if(USE_CUDA) ...@@ -202,6 +202,7 @@ if(USE_CUDA)
fused_attn/fused_attn_fp8.cu fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu fused_attn/utils.cu
gemm/cublaslt_gemm.cu gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
...@@ -225,15 +226,18 @@ if(USE_CUDA) ...@@ -225,15 +226,18 @@ if(USE_CUDA)
list(APPEND transformer_engine_cuda_arch_specific_sources list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu activation/gelu.cu
activation/glu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
cast/cast.cu cast/cast.cu
gemm/cutlass_grouped_gemm.cu gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
...@@ -357,6 +361,7 @@ else() ...@@ -357,6 +361,7 @@ else()
fused_attn/kv_cache.cu fused_attn/kv_cache.cu
fused_attn/utils.cu fused_attn/utils.cu
gemm/cublaslt_gemm.cu gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
gemm/hipblas_gemm.cu gemm/hipblas_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu
...@@ -381,6 +386,7 @@ else() ...@@ -381,6 +386,7 @@ else()
list(APPEND transformer_engine_cuda_arch_specific_sources list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu activation/gelu.cu
activation/glu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
cast/cast.cu cast/cast.cu
...@@ -476,20 +482,18 @@ endif() ...@@ -476,20 +482,18 @@ endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP) if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
find_library(CUBLASMP_LIB find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR} PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib PATH_SUFFIXES lib
REQUIRED) REQUIRED)
find_library(NVSHMEM_HOST_LIB find_library(NCCL_LIB
NAMES nvshmem_host libnvshmem_host.so.3 NAMES nccl libnccl
PATHS ${NVSHMEM_DIR}
PATH_SUFFIXES lib PATH_SUFFIXES lib
REQUIRED) REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif() endif()
if (USE_CUDA) if (USE_CUDA)
...@@ -561,6 +565,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu ...@@ -561,6 +565,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/glu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu) activation/swiglu.cu)
endif() endif()
......
...@@ -246,11 +246,13 @@ def _nvidia_cudart_include_dir() -> str: ...@@ -246,11 +246,13 @@ def _nvidia_cudart_include_dir() -> str:
return "" return ""
# Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
# above doesn't through. However, they don't set "__file__" attribute. # above doesn't throw. However, they don't set "__file__" attribute.
if nvidia.__file__ is None: if nvidia.__file__ is not None:
return "" nvidia_root = Path(nvidia.__file__).parent
else:
nvidia_root = Path(nvidia.__path__[0]) # namespace package
include_dir = Path(nvidia.__file__).parent / "cuda_runtime" include_dir = nvidia_root / "cuda_runtime"
return str(include_dir) if include_dir.exists() else "" return str(include_dir) if include_dir.exists() else ""
......
...@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { ...@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream); act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
} }
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_gelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu); NVTE_API_CALL(nvte_dgelu);
...@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output ...@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream); dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
} }
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati ...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu); NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) ...@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream); act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
} }
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_qgelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu); NVTE_API_CALL(nvte_dqgelu);
...@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu ...@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream); dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
} }
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
...@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat ...@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu); NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine; using namespace transformer_engine;
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_glu);
using namespace transformer_engine;
Empty e = {};
gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dglu);
using namespace transformer_engine;
Empty e = {};
dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
stream);
}
...@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { ...@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream); act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
} }
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_relu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu); NVTE_API_CALL(nvte_drelu);
...@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output ...@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream); dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
} }
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_drelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati ...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu); NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) ...@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream); act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
} }
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_srelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu); NVTE_API_CALL(nvte_dsrelu);
...@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu ...@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream); dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
} }
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsrelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
...@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat ...@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu); NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { ...@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream); act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
} }
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_silu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu); NVTE_API_CALL(nvte_dsilu);
...@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output ...@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream); dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
} }
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsilu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati ...@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu); NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -28,6 +28,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea ...@@ -28,6 +28,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream); dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
} }
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop); NVTE_API_CALL(nvte_quantize_noop);
...@@ -62,6 +71,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d ...@@ -62,6 +71,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
input, activation_input, output, dbias, workspace, nullptr, stream); input, activation_input, output, dbias, workspace, nullptr, stream);
} }
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr const NVTEGroupedTensor activation_input = nullptr;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize); NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -37,6 +37,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { ...@@ -37,6 +37,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
return cols % alignment_requirement == 0; return cols % alignment_requirement == 0;
} }
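// Added comment: the helper below rounds a shared-memory pointer up to the next
// TMA_SHMEM_ALIGNMENT boundary with the standard power-of-two trick (addr + A - 1) & ~(A - 1).
// Worked example, assuming A = TMA_SHMEM_ALIGNMENT = 128: address 0x1005 rounds up to 0x1080.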
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
size_t addr = reinterpret_cast<size_t>(p);
addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
return reinterpret_cast<unsigned char *>(addr);
}
namespace kernel { namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256; constexpr size_t THREADS_PER_BLOCK = 256;
......