Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -23,7 +23,7 @@ ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
(4096, 32, 1280, 2),
(4096, 256, 4096, 6),
(4096, 64, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:2],
......@@ -44,7 +44,7 @@ ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
(4096, 32, 1280, 2, 128),
(4096, 256, 4096, 6, 16),
(4096, 64, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
......
......@@ -74,6 +74,14 @@ if not IS_HIP_EXTENSION:
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)
# Get determinism
_deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
# Reset RNG seed and states
seed = 1234
reset_rng_states()
......@@ -147,6 +155,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
......@@ -162,8 +171,10 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = get_available_attention_backends(
......@@ -172,6 +183,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -421,6 +433,15 @@ def test_dpa_softmax(dtype, model_configs, model):
)
@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
model_configs_mla = {
# TODO: FlashAttention on ROCm only supports MLA with head_dim_qk = head_dim_v
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
......@@ -685,9 +706,10 @@ model_configs_swa = {
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
model_configs_alibi_slopes = {
......@@ -889,11 +911,14 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
......@@ -1295,6 +1320,7 @@ def test_transformer_layer(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
......@@ -1308,6 +1334,7 @@ def test_transformer_layer(
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -1435,10 +1462,13 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
......@@ -1632,6 +1662,7 @@ def test_dpa_fp8_extra_state(model, dtype):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
......@@ -1822,6 +1853,7 @@ def test_mha_fp8_vs_f16(
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
......@@ -1833,6 +1865,7 @@ def test_mha_fp8_vs_f16(
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
......@@ -1841,6 +1874,7 @@ def test_mha_fp8_vs_f16(
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
......@@ -1850,6 +1884,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
......@@ -1859,6 +1894,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
......@@ -2071,6 +2107,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
......@@ -2081,6 +2118,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
......@@ -2091,6 +2129,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2100,6 +2139,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if unfused_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2108,6 +2148,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2116,6 +2157,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
......@@ -2370,13 +2412,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedDotProductAttention"
)
atol = 5e-1
rtol = 5e-1
......@@ -2409,10 +2454,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
......@@ -2463,10 +2511,13 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
......@@ -2754,7 +2805,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
cu_seqlens,
max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, num_gemms=3) as inp:
with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
......
......@@ -148,7 +148,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
......@@ -164,7 +164,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
......@@ -188,7 +188,16 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = [
"cp_1_0",
"cp_1_1",
"cp_1_4",
"cp_2_0",
"cp_2_2",
"cp_2_4",
"cp_3_2",
"cp_4_2",
]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......@@ -284,9 +293,14 @@ def test_cp_with_fused_attention(
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
if (
get_cudnn_version() < (9, 18, 0)
and config.softmax_type != "vanilla"
and qkv_format == "thd"
):
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
"Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
" non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
......
......@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
......@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
LOG_QUANTIZED_CONFIG_BASE = """
log:
......@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState._reset()
# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogNvfp4TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
def test_nvfp4_numeric(feature_dirs):
"""Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantization import RecipeState
recipe_state = RecipeState.create(
recipe.NVFP4BlockScaling(),
mode="forward",
num_quantizers=3,
)
# Create test tensor with known distribution
torch.manual_seed(42)
tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Add some small values that should underflow to zero in FP4
tensor[0, :16] = 0.0001
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="test_layer",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
# Validate both stats are present
assert "nvfp4_underflows%" in output, "underflows% stat missing"
assert "nvfp4_mse" in output, "mse stat missing"
# Extract values and validate numerics
underflows_value = None
mse_value = None
for line in output.splitlines():
if "nvfp4_underflows%" in line and "value=" in line:
underflows_value = float(line.split("value=")[1].split()[0])
if "nvfp4_mse" in line and "value=" in line:
mse_value = float(line.split("value=")[1].split()[0])
# Compute expected underflows: non-zero elements that became zero after quantization
orig_nonzero_mask = tensor != 0
dequant_zero_mask = dequantized_tensor == 0
expected_underflows = (
(orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
)
# Allow some tolerance
assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
# Compute expected MSE
expected_mse = torch.nn.functional.mse_loss(
dequantized_tensor.float(), tensor.float(), reduction="mean"
)
assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
"""Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
# Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
with debug_session(log_fp8_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Should work - recipe-prefixed stats compute MXFP8 separately for comparison
for _ in range(2):
with te.autocast(recipe=recipe.NVFP4BlockScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
# Should have logged MXFP8 MSE stat (what-if scenario)
assert "mxfp8_mse" in output
def test_log_grouped_gemm(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
......
......@@ -30,10 +30,17 @@ configs = {
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
""",
"log_fp8": """log_fp8:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows, overflows]
stats: [underflows%]
start_step : 0
end_step: 1
""",
......@@ -46,22 +53,26 @@ fake_quant_config:
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
tensors: [activation, weight, gradient]
quant_format: FP8E5M2
""",
}
# Configs that require FP8 to be enabled
fp8_required_configs = {"log_fp8"}
def _get_model(model_key):
if model_key == "linear":
return te.Linear(D, D)
return te.Linear(D, D, name="layer")
if model_key == "layernorm_linear":
return te.LayerNormLinear(D, D)
return te.LayerNormLinear(D, D, name="layer")
if model_key == "layernorm_mlp":
return te.LayerNormMLP(D, D, D)
return te.LayerNormMLP(D, D, D, name="layer")
if model_key == "mha_attention":
return te.MultiheadAttention(D, H)
return te.MultiheadAttention(D, H, name="layer")
if model_key == "transformer_layer":
return te.TransformerLayer(D, D, H)
return te.TransformerLayer(D, D, H, name="layer")
def _run_forward_backward(model, fp8):
......@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if not fp8 and config_key in fp8_required_configs:
pytest.skip(f"Config '{config_key}' requires FP8")
_run_test(model_key, fp8, configs[config_key], feature_dirs)
......@@ -101,7 +101,7 @@ class TestLoadCheckpoint:
# Path to save checkpoint
if checkpoint_dir is None:
checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
checkpoint_dir.mkdir(exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = checkpoint_dir / f"{name}.pt"
# Create module and save checkpoint
......
......@@ -5,8 +5,10 @@
from __future__ import annotations
from collections.abc import Iterable
import functools
import io
import math
import random
from typing import Optional
import pytest
......@@ -37,7 +39,14 @@ from transformer_engine.pytorch import (
import transformer_engine_torch as tex
# Import utility functions
from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
from utils import (
assert_close,
assert_close_grads,
dtype_tols,
make_recipe,
quantization_tols,
reset_rng_states,
)
if IS_HIP_EXTENSION:
import os
......@@ -116,6 +125,9 @@ def maybe_skip_quantization(
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
*,
min: float = 0.0,
max: float = 1.0,
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
......@@ -136,7 +148,8 @@ def make_reference_and_test_tensors(
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
ref.uniform_(min, max)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
......@@ -1569,7 +1582,19 @@ class TestBasicOps:
@pytest.mark.parametrize(
"activation",
("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
(
"gelu",
"geglu",
"qgelu",
"qgeglu",
"relu",
"reglu",
"glu",
"srelu",
"sreglu",
"silu",
"swiglu",
),
)
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
......@@ -1589,7 +1614,7 @@ class TestBasicOps:
# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
in_shape[-1] *= 2
# Skip invalid configurations
......@@ -1629,6 +1654,13 @@ class TestBasicOps:
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
elif activation == "glu":
x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
x = x.flip(-2) # PyTorch GLU swaps gate and linear unit
x = x.reshape(in_shape)
y_ref = torch.nn.functional.glu(x)
elif activation == "srelu":
y_ref = torch.nn.functional.relu(x_ref) ** 2
elif activation == "sreglu":
......@@ -1648,6 +1680,7 @@ class TestBasicOps:
make_op = dict(
gelu=te_ops.GELU,
geglu=te_ops.GEGLU,
glu=te_ops.GLU,
qgelu=te_ops.QGELU,
qgeglu=te_ops.QGEGLU,
relu=te_ops.ReLU,
......@@ -1692,6 +1725,7 @@ class TestBasicOps:
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_interleave_size: Optional[int] = None,
):
# Tensor dimensions
......@@ -1718,7 +1752,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)
......@@ -1726,7 +1770,7 @@ class TestBasicOps:
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.SwiGLU(),
te_ops.SwiGLU(glu_interleave_size=glu_interleave_size),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
......@@ -1739,10 +1783,19 @@ class TestBasicOps:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
def test_interleaved_swiglu(self):
"""SwiGLU with block interleaved input format"""
self.test_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
......@@ -1752,6 +1805,7 @@ class TestBasicOps:
self,
*,
out_shape: Iterable[int] = (32, 32),
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
......@@ -1760,7 +1814,7 @@ class TestBasicOps:
limit: float = 0.75,
alpha: float = 1.702,
):
# Test SwiGLU variant used in GPT OSS.
"""SwiGLU variant used in GPT-OSS"""
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
......@@ -1785,7 +1839,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
x_glu, x_linear = x_ref.chunk(2, dim=-1)
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
......@@ -1797,7 +1861,11 @@ class TestBasicOps:
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
te_ops.ClampedSwiGLU(
limit=limit,
alpha=alpha,
glu_interleave_size=glu_interleave_size,
),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
......@@ -1813,10 +1881,19 @@ class TestBasicOps:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
def test_interleaved_clamped_swiglu(self):
"""GPT-OSS SwiGLU with block interleaved input format"""
self.test_clamped_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
......@@ -1936,6 +2013,231 @@ class TestBasicOps:
abs(z_score) < 2.5758
), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_grouped_linear(
self,
*,
group_size: int = 4,
bias: bool,
weight_shape: tuple[int, int] = (128, 128),
split_alignment: int = 128,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""Grouped GEMM"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = (split_sizes.sum().item(), in_features)
out_shape = (in_shape[0], out_features)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
ws_ref, ws_test = [], []
bs_ref, bs_test = [], []
for _ in range(group_size):
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
ws_ref.append(w_ref)
ws_test.append(w_test)
bs_ref.append(b_ref)
bs_test.append(b_test)
# Plain PyTorch implementation
xs_ref = torch.split(x_ref, split_sizes.tolist())
ys_ref = []
for x, w, b in zip(xs_ref, ws_ref, bs_ref):
ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
y_ref = torch.cat(ys_ref)
if input_requires_grad or weight_requires_grad:
y_ref.backward(dy_ref)
# Construct fusible operation
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.GroupedLinear(
group_size,
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
for group_idx in range(group_size):
getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
if bias:
getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
del ws_test, bs_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
# Forward and backward pass with op
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test, split_sizes)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
else:
assert x_test.grad is None
for group_idx in range(group_size):
w_test = getattr(op, f"weight{group_idx}")
if weight_requires_grad:
dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
else:
assert w_test.grad is None
if bias:
b_test = getattr(op, f"bias{group_idx}")
if weight_requires_grad:
db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
else:
assert b_test.grad is None
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_swiglu(
self,
*,
in_shape: Iterable[int],
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SwiGLU with post-scale"""
# Tensor dims
out_shape = list(in_shape)
out_shape[-1] //= 2
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
-1,
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y = torch.nn.functional.silu(x1) * x2
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)
def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
in_shape=(32, 192),
glu_interleave_size=32,
input_requires_grad=True,
scales_requires_grad=True,
)
class TestFusedOps:
"""Tests for fused operations"""
......@@ -2342,13 +2644,13 @@ class TestFusedOps:
backward_ops = model._module_groups[0]._backward_ops
if with_quantization:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], BackwardActivationBias)
else:
assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type)
assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize)
assert isinstance(backward_ops[2][0], act_type)
# Expected numerical error
tols = dtype_tols(dtype)
......@@ -2944,3 +3246,499 @@ class TestSequentialModules:
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
def test_grouped_mlp(
self,
*,
group_size: int = 4,
bias: bool,
hidden_size: int = 256,
dtype: torch.dtype,
quantization: Optional[str],
device: torch.device = "cuda",
split_alignment: int = 256,
glu_interleave_size: Optional[int],
) -> None:
"""GroupedLinear + ScaledSwiGLU + GroupedLinear"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input shape
in_shape = (split_sizes.sum().item(), hidden_size)
out_shape = in_shape
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
probs_ref, probs_test = make_reference_and_test_tensors(
(in_shape[0],),
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref, fc1_ws_test = [], []
fc1_bs_ref, fc1_bs_test = [], []
fc2_ws_ref, fc2_ws_test = [], []
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc2_w_ref, fc2_w_test = make_reference_and_test_tensors(
(hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc1_b_ref, fc1_b_test = None, None
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc2_b_ref, fc2_b_test = make_reference_and_test_tensors(
(hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref.append(fc1_w_ref)
fc1_bs_ref.append(fc1_b_ref)
fc1_ws_test.append(fc1_w_test)
fc1_bs_test.append(fc1_b_test)
fc2_ws_ref.append(fc2_w_ref)
fc2_bs_ref.append(fc2_b_ref)
fc2_ws_test.append(fc2_w_test)
fc2_bs_test.append(fc2_b_test)
# Reference implementation
xs = torch.split(x_ref, split_sizes.tolist())
probs = torch.split(probs_ref, split_sizes.tolist())
ys = []
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
ys.append(x)
y_ref = torch.cat(ys)
y_ref.backward(dy_ref)
# Construct operations
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
fc2 = te_ops.GroupedLinear(
group_size,
hidden_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
module = te_ops.Sequential(
fc1,
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
fc2,
)
# Copy weights
with torch.no_grad():
for group_idx in range(group_size):
getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx])
getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx])
if bias:
getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx])
getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx])
del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test
# Fuse ops and perform forward and backward pass
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = module(x_test, split_sizes, probs_test, split_sizes)
y_test.backward(dy_test)
# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
if quantization == "nvfp4":
tols = {"rtol": 0.25, "atol": 0.5}
# Check values
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(probs_test, probs_ref, **tols)
for group_idx in range(group_size):
assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)
class TestCustomOps:
"""Test with ops that are defined externally"""
def test_custom_basic_op(
self,
*,
shape: Iterable[int] = (7, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
) -> None:
"""Custom basic op"""
class CustomScaleOp(te.ops.BasicOperation):
"""Custom op that applies a learnable scale"""
def __init__(self) -> None:
super().__init__()
self.scale: torch.nn.Parameter
scale = torch.ones((), dtype=dtype, device=device)
scale = torch.nn.Parameter(scale)
self.register_parameter("scale", scale)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.save_for_backward(self.scale, input_)
return self.scale * input_
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> torch.Tensor:
(
scale,
input_,
) = ctx.saved_tensors
grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
grad_scale = grad_scale.reshape(())
grad_input = scale * grad_output
return grad_input, (grad_scale,)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = w_ref * x_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = CustomScaleOp()
forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
with torch.no_grad():
op.scale.copy_(w_test)
del w_test
y_test = forward(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_forward_fused_op(
self,
*,
shape: Iterable[int] = (7, 11),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in forward pass"""
class CustomForwardLinearSiLU(te.ops.FusedOperation):
"""Custom fused op for GEMM + SiLU"""
_enabled = True
def __init__(self, *, linear, silu) -> None:
super().__init__((linear, silu))
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
**unused,
) -> torch.Tensor:
weight = self.basic_ops[0].weight
dtype = weight.dtype
device = weight.device
# Perform compute on CPU, because why not?
x = input_.cpu()
w = weight.cpu()
y = torch.matmul(x, w.T)
z = torch.nn.functional.silu(y)
out = z.to(device=device)
# Save state for linear backward
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx.save_for_backward(input_, weight)
linear_op_ctx.with_quantized_compute = False
linear_op_ctx.input_quantizer = None
linear_op_ctx.weight_quantizer = None
linear_op_ctx.grad_output_quantizer = None
linear_op_ctx.grad_input_quantizer = None
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = True
linear_op_ctx.weight_requires_grad = True
# Save state for SiLU backward
silu_op_ctx = basic_op_ctxs[1]
silu_op_ctx.save_for_backward(y.to(device=device))
silu_op_ctx.dtype = dtype
silu_op_ctx.prev_op_grad_output_quantizer = None
return out, [(), ()]
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomForwardLinearSiLU._enabled:
CustomForwardLinearSiLU._enabled = False
op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref = torch.nn.functional.silu(y_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
model = te.ops.Sequential(
te.ops.Linear(shape[-1], shape[-1], bias=False),
te.ops.SiLU(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_backward_fused_op(
self,
*,
shape: Iterable[int] = (13, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in backward pass"""
class CustomBackwardLinearScale(te.ops.FusedOperation):
"""Custom fused op for backward linear + scale"""
_enabled: bool = True
def __init__(self, *, scale, linear) -> None:
super().__init__((scale, linear))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
**unused,
) -> torch.Tensor:
# Load state from linear forward
linear_op_ctx = basic_op_ctxs[1]
x, w = linear_op_ctx.saved_tensors
dtype = linear_op_ctx.dtype
device = w.device
# Perform compute in FP64 and apply scale before dgrad
# GEMM instead of after
scale = self.basic_ops[0].scale
dy = grad_output.double()
x = x.double()
w = w.double()
dx = torch.matmul(dy, scale * w)
dw = torch.matmul(dy.T, x)
dx = dx.to(dtype=dtype)
dw = dw.to(dtype=dtype)
return dx, [(), (dw,)], [(), ()]
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomBackwardLinearScale._enabled:
CustomBackwardLinearScale._enabled = False
op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
scale = 1.234
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
model = te.ops.Sequential(
te.ops.ConstantScale(scale),
te.ops.Linear(shape[-1], shape[-1], bias=False),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for GroupedTensor class"""
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex
# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
_quantization_params = [
pytest.param(
"fp8_delayed_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_current_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_blockwise",
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
pytest.param(
"mxfp8",
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
"nvfp4",
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
),
]
def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
"""Create quantizers for given quantization scheme"""
if quantization == "fp8_delayed_scaling":
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device="cuda"),
amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
fp8_dtype=tex.DType.kFloat8E4M3,
)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device="cuda",
)
quantizer.set_usage(rowwise=True, columnwise=False)
elif quantization == "fp8_blockwise":
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=False,
force_pow_2_scales=True,
amax_epsilon=0.0,
block_scaling_dim=1,
)
elif quantization == "mxfp8":
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
elif quantization == "nvfp4":
quantizer = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)
else:
raise ValueError(f"Unknown quantization scheme: {quantization}")
quantizer.internal = False
return quantizer
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
return qtensor._data
if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
return qtensor._rowwise_data
raise ValueError(f"Unknown quantization scheme: {quantization}")
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
if quantization == "nvfp4":
return numel // 2
return numel
class TestGroupedTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_basic_construction_all_same_shape(self) -> None:
"""Test GroupedTensor construction with all tensors having same shape"""
num_tensors = 4
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.all_same_shape()
assert grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
assert grouped_tensor.get_common_first_dim() == 256
assert grouped_tensor.get_common_last_dim() == 512
assert grouped_tensor.has_data()
def test_basic_construction_varying_first_dim(self) -> None:
"""Test GroupedTensor construction with varying first dimension"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert not grouped_tensor.all_same_shape()
assert not grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.get_common_last_dim() == shape[0][1]
assert grouped_tensor.logical_shape == (
sum(v for v, _ in shape),
shape[0][1],
) # sum of first dims
def test_split_into_quantized_tensors_no_quantization(self) -> None:
"""Test split_into_quantized_tensors for unquantized tensors"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor has correct shape and shares storage
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
assert isinstance(tensor, torch.Tensor)
assert not hasattr(tensor, "_data") # Not a quantized tensor
# Verify data pointer is within the original grouped tensor storage
# The tensor should be a view of the original data
assert tensor.data_ptr() >= original_data_ptr
# Calculate expected offset
expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor shares storage with the grouped tensor
for i, tensor in enumerate(tensors):
rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
assert rowwise_data is not None
assert rowwise_data.data_ptr() >= original_data_ptr
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_split_varying_shapes(self) -> None:
"""Test split_into_quantized_tensors with varying shapes"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
original_data_ptr = grouped_tensor.data.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify shapes and storage
cumulative_offset = 0
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
expected_offset = cumulative_offset * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
cumulative_offset += shape[i][0] * shape[i][1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
# Verify returned tensors point to the same storage
for i, qtensor in enumerate(quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
# Verify each tensor points to correct location
cumulative_numel = 0
for qtensor, tensor_shape in zip(quantized_tensors, shape):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
quantizer=quantizers,
device="cuda",
)
# Verify the grouped tensor was created correctly
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.has_data()
# Verify quantized_tensors were created and point to same storage
assert grouped_tensor.quantized_tensors is not None
assert len(grouped_tensor.quantized_tensors) == num_tensors
original_data_ptr = grouped_tensor.data.data_ptr()
for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.has_data()
assert grouped_tensor.num_tensors == num_tensors
grouped_tensor.clear()
assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.logical_shape == (0, 0)
......@@ -94,6 +94,7 @@ all_boolean = [True, False]
all_activations = [
"gelu",
"geglu",
"glu",
"qgelu",
"qgeglu",
"relu",
......@@ -484,6 +485,7 @@ class TorchGroupedLinearWithPadding(nn.Module):
_supported_act = {
"gelu": nn.GELU(approximate="tanh"),
"geglu": nn.GELU(approximate="tanh"),
"glu": nn.Sigmoid(),
"qgelu": TorchQuickGELU(),
"qgeglu": TorchQuickGELU(),
"relu": nn.ReLU(),
......
......@@ -745,6 +745,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
_test_export_layernorm_mlp(activation=activation)
# Quantization recipes with fp8_dpa=True for attention emulation export test
dpa_quantization_recipes = [None] # None = no quantization
if fp8_available:
dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
......@@ -762,6 +770,7 @@ def test_export_core_attention(
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
fp8_recipe: recipe.Recipe,
):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
......@@ -783,22 +792,25 @@ def test_export_core_attention(
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
is_fp8 = fp8_recipe is not None
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16,):
return
atol = 5e-1 if is_fp8 else 1e-2
validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
)
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
from typing import Optional
from typing import Optional, List
import torch
import pytest
......@@ -114,6 +114,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = [
"gelu",
"geglu",
"glu",
"qgelu",
"qgeglu",
"relu",
......@@ -138,6 +139,117 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
"""
Verify that tensors are stored in contiguous memory.
Args:
tensors: List or iterable of tensors to check
num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
tensor_name: Name to use in error messages
"""
tensor_list = list(tensors)
if len(tensor_list) < 2:
return # Nothing to check
for i in range(1, len(tensor_list)):
prev_tensor = tensor_list[i - 1]
curr_tensor = tensor_list[i]
# Calculate expected offset based on previous tensor size
prev_numel = prev_tensor.numel()
expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
# Verify current tensor's data pointer is correctly offset
expected_ptr = prev_tensor.data_ptr() + expected_offset
actual_ptr = curr_tensor.data_ptr()
assert (
actual_ptr == expected_ptr
), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"
def check_grouped_tensor_pointers(
weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
"""
Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
TODO(ksivaman): This check could be made much more efficient; for now we use the brute-force approach.
"""
num_elems_in_a_data_byte = 2 if (fp8_recipe is not None and fp8_recipe.nvfp4()) else 1
# Check data.
if hasattr(weights[0], "_data") and weights[0]._data is not None:
data_tensors = [w._data for w in weights]
check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")
# Check transpose.
if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
transpose_tensors = [w._transpose for w in weights]
check_grouped_tensor_pointers_helper(
transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
)
# Check scale_inv.
if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
scale_inv_tensors = [w._scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
)
# Check rowwise scale_inv.
if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
)
# Check columnwise scale_inv.
if (
hasattr(weights[0], "_columnwise_scale_inv")
and weights[0]._columnwise_scale_inv is not None
):
columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_scale_inv_tensors,
num_elems_in_byte=1,
tensor_name="columnwise scale_inv",
)
# Check rowwise amax.
if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
rowwise_amax_tensors = [w._rowwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
)
# Check columnwise amax.
if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
columnwise_amax_tensors = [w._columnwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
)
# Check rowwise data.
if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
rowwise_data_tensors = [w._rowwise_data for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="rowwise data",
)
# Check columnwise data.
if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
columnwise_data_tensors = [w._columnwise_data for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="columnwise data",
)
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
......@@ -486,10 +598,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
dtype,
bs,
model,
fp8_recipe,
fp8_model_params,
use_bias,
single_param,
num_gemms,
empty_split,
):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.")
......@@ -499,6 +620,9 @@ def test_sanity_grouped_linear(
bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if single_param:
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -508,9 +632,19 @@ def test_sanity_grouped_linear(
use_fp8 = fp8_recipe is not None
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
num_gemms,
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype,
).cuda()
# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
if single_param:
check_grouped_tensor_pointers(weights, fp8_recipe)
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -528,6 +662,9 @@ def test_sanity_grouped_linear(
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)
if single_param:
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
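The single_param path exercises the contiguous weight allocation that GroupedLinear uses when NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS is set. As a hedged alternative to the manual set/del above (which can leak the flag if an assertion fails first), the same toggle could be applied through pytest's monkeypatch fixture; the test name and sizes in this sketch are hypothetical.
def test_grouped_linear_contiguous_weights(monkeypatch):  # hypothetical test, sketch only
    # monkeypatch.setenv restores the environment at teardown, even if an assertion fails.
    monkeypatch.setenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "1")
    linear = GroupedLinear(4, 128, 256, params_dtype=torch.float32).cuda()
    weights = [getattr(linear, f"weight{i}") for i in range(4)]
    check_grouped_tensor_pointers(weights)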
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
......@@ -1005,7 +1142,13 @@ def test_replace_raw_data_for_float8tensor():
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
attrs_to_check = [
"_quantizer",
"_fp8_dtype",
"_scale_inv",
"_transpose",
"_transpose_invalid",
]
attrs = {}
for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr)
......
......@@ -15,7 +15,7 @@ import torch
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import InferenceParams
from transformer_engine.pytorch import InferenceParams, QuantizedTensor
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
......@@ -353,7 +353,7 @@ def get_available_attention_backends(
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
......@@ -361,3 +361,48 @@ def get_available_attention_backends(
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
@torch.no_grad
def assert_close(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
*,
check_device: bool = False,
check_dtype: bool = False,
check_layout: bool = False,
**kwargs,
) -> None:
"""Assert that two tensors are close.
This function is a wrapper around torch.testing.assert_close. It
changes the defaults for device and dtype checks (useful when the
reference implementation is computed in high precision on CPU) and
it can handle quantized tensors.
"""
if isinstance(actual, QuantizedTensor):
actual = actual.dequantize()
if isinstance(expected, QuantizedTensor):
expected = expected.dequantize()
torch.testing.assert_close(
actual,
expected,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
**kwargs,
)
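A minimal usage sketch of this wrapper (the tensors below are made up and rely on the torch import already present in this file): a low-precision result can be compared against a higher-precision reference without tripping the dtype check that torch.testing.assert_close applies by default.
ref = torch.ones(4, dtype=torch.float64)       # high-precision reference
out = torch.ones(4, dtype=torch.bfloat16)      # low-precision result under test
assert_close(out, ref, rtol=0.0, atol=1e-2)    # passes despite the dtype mismatch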
def assert_close_grads(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
**kwargs,
) -> None:
"""Assert that two tensors have close gradients."""
if actual is None and expected is None:
return
assert actual is not None
assert expected is not None
assert_close(actual.grad, expected.grad, **kwargs)
......@@ -202,6 +202,7 @@ if(USE_CUDA)
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
......@@ -225,15 +226,18 @@ if(USE_CUDA)
list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
......@@ -357,6 +361,7 @@ else()
fused_attn/kv_cache.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
gemm/hipblas_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
......@@ -381,6 +386,7 @@ else()
list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu
cast/cast.cu
......@@ -476,20 +482,18 @@ endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NVSHMEM_HOST_LIB
NAMES nvshmem_host libnvshmem_host.so.3
PATHS ${NVSHMEM_DIR}
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()
if (USE_CUDA)
......@@ -561,6 +565,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu)
endif()
......
......@@ -246,11 +246,13 @@ def _nvidia_cudart_include_dir() -> str:
return ""
# Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
# above doesn't through. However, they don't set "__file__" attribute.
if nvidia.__file__ is None:
return ""
# above doesn't throw. However, they don't set "__file__" attribute.
if nvidia.__file__ is not None:
nvidia_root = Path(nvidia.__file__).parent
else:
nvidia_root = Path(nvidia.__path__[0]) # namespace package
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
include_dir = nvidia_root / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
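The branch above distinguishes a regular package, whose __file__ points at its __init__.py, from a namespace package, which exposes its directories only through __path__. The sketch below reproduces the difference with a throwaway namespace package; nspkg and the temporary directory are hypothetical.
# Sketch only: build a namespace package on the fly to show the difference.
import importlib
import pathlib
import sys
import tempfile
tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "nspkg").mkdir()                          # a directory with no __init__.py
sys.path.insert(0, str(tmp))
mod = importlib.import_module("nspkg")
assert getattr(mod, "__file__", None) is None    # namespace packages have no file
assert list(mod.__path__)                        # but they do expose a search path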
......
......@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_gelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
......@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
......@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_qgelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
......@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_glu);
using namespace transformer_engine;
Empty e = {};
gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dglu);
using namespace transformer_engine;
Empty e = {};
dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
stream);
}
......@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_relu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
......@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_drelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
......@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_srelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
......@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsrelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
......@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_silu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr,
stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
......@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsilu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
......@@ -28,6 +28,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop);
......@@ -62,6 +71,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr const NVTEGroupedTensor activation_input = nullptr;
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
......
......@@ -37,6 +37,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
return cols % alignment_requirement == 0;
}
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
size_t addr = reinterpret_cast<size_t>(p);
addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
return reinterpret_cast<unsigned char *>(addr);
}
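This helper rounds a shared-memory address up to the next TMA_SHMEM_ALIGNMENT boundary with the usual power-of-two trick: add (alignment - 1), then mask off the low bits. The Python sketch below only checks that arithmetic; ALIGN = 128 is an assumption standing in for TMA_SHMEM_ALIGNMENT and the addresses are arbitrary.
# Sketch of the round-up-to-alignment arithmetic used above (values are arbitrary).
ALIGN = 128  # assumed stand-in for TMA_SHMEM_ALIGNMENT
def align_up(addr: int) -> int:
    return (addr + ALIGN - 1) & ~(ALIGN - 1)
assert align_up(0) == 0        # already-aligned addresses are unchanged
assert align_up(1) == 128      # anything past a boundary rounds up
assert align_up(128) == 128
assert align_up(129) == 256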
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
......