Commit 44740c6c authored by yuguo


Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -106,7 +106,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
if NVTE_TEST_NVINSPECT_ENABLED:
# The numerics of all the layers should work the same,
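The cast to int is the substantive part of this hunk: os.environ.get returns the raw string whenever the variable is set, so the old False default made even NVTE_TEST_NVINSPECT_ENABLED=0 evaluate as truthy. A minimal, self-contained illustration (the assertions are only for demonstration):

import os

os.environ["NVTE_TEST_NVINSPECT_ENABLED"] = "0"

# Old pattern: the result is the string "0", and any non-empty string is truthy.
old_flag = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)
assert bool(old_flag) is True

# New pattern: parse the value, so "0" really disables the debug path.
new_flag = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
assert bool(new_flag) is False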
...@@ -1059,8 +1059,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
reset_rng_states()
fp8 = recipe is not None
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
...@@ -1070,9 +1073,10 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False)
)
inp_hidden_states.retain_grad()
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
loss = out.sum()
loss.backward()
if delay_wgrad_compute:
...@@ -1268,6 +1272,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
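With the new recipe argument, _test_granular_accuracy runs the forward pass under fp8_autocast whenever a recipe is supplied, while the backward pass stays outside the context, which is the usual Transformer Engine pattern. A self-contained sketch of that pattern; the layer sizes and the choice of DelayedScaling are illustrative only:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(256, 256, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
fp8_recipe = recipe.DelayedScaling()  # illustrative choice of recipe

# Forward under FP8 autocast when a recipe is given; backward stays outside the context.
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
    out = layer(x)
out.sum().backward()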
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_linear_accuracy_save_original_input(dtype, model, recipe):
bs = 1
fuse_wgrad_accumulation = True
fp8_model_params = False
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
save_original_input=False,
).eval()
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
save_original_input=True,
).eval()
# Share params
with torch.no_grad():
te_linear_ref.weight = Parameter(te_linear.weight.clone())
if fuse_wgrad_accumulation:
weight = getattr(te_linear, f"weight")
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
te_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
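Like the grouped-linear variants below, this test relies on the fuse_wgrad_accumulation contract: the caller pre-allocates an FP32 weight.main_grad buffer and the layer accumulates the weight gradient into it during backward. A minimal sketch of that setup, assuming the Megatron-style main_grad convention these tests use:

import torch
from transformer_engine.pytorch import Linear

layer = Linear(256, 1024, bias=False, params_dtype=torch.bfloat16,
               device="cuda", fuse_wgrad_accumulation=True)
# The caller owns the FP32 accumulation buffer; the layer adds each step's wgrad into it.
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

out = layer(torch.randn(128, 256, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
# After backward, the weight gradient has been accumulated into layer.weight.main_grad.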
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
...@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy( ...@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy(
device="cuda", device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation, fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute, delay_wgrad_compute=delay_wgrad_compute,
save_original_input=False,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
outputs = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
def test_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
save_original_input=True,
).eval()
sequential_linear = torch.nn.ModuleList(
[
...@@ -1948,7 +2115,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
save_original_input=False,
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
def test_padding_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
...@@ -1958,6 +2207,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
...@@ -1983,6 +2234,7 @@ def test_padding_grouped_linear_accuracy(
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
save_original_input=True,
).eval()
# Share params
...@@ -2334,9 +2586,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 12, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains tests for exporting TransformerEngine models to ONNX.
The purpose of these tests is to validate that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.
Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.
To run many repetitive tests use pytest-loop:
$ python3 -m pip install pytest-loop
$ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm
For reproducibility use: torch.manual_seed(0)
"""
import os
import tempfile
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
# Global test configuration knobs.
# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))
if SAVE_TEST_IO:
from polygraphy.json import save_json
from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
tempfile.gettempdir(), "./gen_onnx_models"
)
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
fp8_recipes = [
None,
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
]
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@onnx_op(
op_type="trt::TRT_FP8QuantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_float,
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale):
"""FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
return q(x)._data.cpu().numpy()
@onnx_op(
op_type="trt::TRT_FP8DequantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_uint8,
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale):
"""FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
quantizer_tensor = q.create_tensor_from_data(x, fake_dtype=torch.float32)
return quantizer_tensor.dequantize().cpu().numpy()
@onnx_op(
op_type="trt::TRT_MXFP8QuantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_float,
],
outputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8],
)
def trt_mxfp8_quantize(t):
"""MXFP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
qx = q(x)
return qx._rowwise_data.cpu().numpy(), qx._rowwise_scale_inv.cpu().numpy()
@onnx_op(
op_type="trt::TRT_MXFP8DequantizeLinear",
domain="trt",
inputs=[
PyCustomOpDef.dt_uint8,
PyCustomOpDef.dt_uint8,
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_mxfp8_dequantize(t, scale_inv):
"""MXFP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
scale_inv_tensor = torch.from_numpy(scale_inv).cuda()
q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantizer_tensor = q.create_tensor_from_data(x, scale_inv_tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize().cpu().numpy()
@pytest.fixture()
def seed_default_rng():
"""Reseed the PRNG for test reproducibility"""
torch.manual_seed(1234)
@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
"""Set the maximum sequence length that can be used for attention masking"""
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def do_export(
model: torch.nn.Module,
inp: torch.Tensor,
fname: str,
fp8_recipe: recipe.Recipe,
input_names: List[str] = None,
output_names: List[str] = None,
dynamic_shapes: List[str] = None,
):
"""Export to ONNX"""
input_names = input_names or ["input"]
output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
model.cuda().eval()
os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
assert len(inps) == len(input_names)
inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]
model(*inps) # warm-up run
with te.export.onnx_export(True):
model(*inps)
with te.export.onnx_export(True):
torch.onnx.export(
model,
inps,
fname,
dynamo=True,
custom_translation_table=te_translation_table,
verbose=True,
dynamic_shapes=dynamic_shapes,
input_names=input_names,
output_names=output_names,
optimize=inps[0].dtype
!= torch.bfloat16, # optimizer does not work with bfloat16 yet - will need to change that after onnxscript supports bfloat16
)
def to_numpy(tensor):
if isinstance(tensor, torch.Tensor):
if tensor.dtype == torch.bfloat16:
tensor = tensor.type(torch.float32)
tensor = tensor.detach().cpu().numpy()
return tensor
def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
"""Initialize the FP8 quantization scales in module"""
module.init_fp8_metadata(num_gemms)
for quantizer in module.quantizers["scaling_fwd"]:
quantizer.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
def te_infer(
model: torch.nn.Module,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
is_fp8: bool,
fp8_recipe: recipe.Recipe,
):
"""Transformer Engine forward propagation."""
with torch.inference_mode(), te.fp8_autocast(
enabled=is_fp8, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
return te_outputs
def compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
"""Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
te_output = to_numpy(te_output)
onnx_output = to_numpy(onnx_output)
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
if mismatched_ids:
# Log some information in case of error.
print("*" * 100)
nb_errors = len(mismatched_ids)
nb_vals = min(nb_errors, max_errors_printed)
print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
print(f"Showing first {nb_vals} errors (ONNX -- TE):")
abs_err = np.abs(onnx_output - te_output)
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(
f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
f" {atol + rtol * abs(ref)}"
)
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def serialize_inputs_outputs(
fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor],
te_outputs: List[torch.Tensor],
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
):
if not SAVE_TEST_IO:
return
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
input_names = input_names or ["input"]
output_names = output_names or ["output"]
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
json_fname = fname[: -len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
json_fname = fname[: -len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults()
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save(json_fname)
def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module,
atol: float = 1.0e-8, # np.isclose default atol
rtol: float = 1.0e-5, # np.isclose default rtol
max_errors_printed: int = 10,
is_fp8: bool = False,
allow_cnt_errors: int = 0,
input_names: List[str] = None,
output_names: List[str] = None,
te_outputs: List[torch.Tensor] = None,
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
The purpose of the output comparison is to validate that TE models are converted to
their correct ONNX representation by testing that TE and ORT outputs match within some
small threshold (allowing for finite precision errors).
Argument `allow_cnt_errors` reduces test failure noise from spurious errors by ignoring
a very small number (0-3) of outliers. This is acceptable because these outliers are due to
small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
representation (the tests assume both the ORT and TE kernels are correct).
Argument `te_outputs` can be used to provide pre-computed TE outputs.
"""
def create_ort_session(fname: str, is_fp8: bool):
"""Create an ONNX Runtime session for validation."""
def load_custom_ops(session_opts: ort.SessionOptions):
"""For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
session_opts.register_custom_ops_library(get_library_path())
print("registered custom FP8 Q/DQ ops!")
kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
kwargs["sess_options"] = sess_options
s = ort.InferenceSession(fname, **kwargs)
return s
def create_ort_input_dict(session, inputs):
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
input_names = [x.name for x in session.get_inputs()]
inps = [to_numpy(x) for x in inputs if x is not None]
inp_dict = dict(zip(input_names, inps))
return inp_dict
input_names = input_names or ["input"]
output_names = output_names or ["output"]
# Run ORT session and TE model.
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
if not te_outputs:
te_outputs = te_infer(model, inps, is_fp8, fp8_recipe=None)
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
if fake_bf16_io:
assert dtype == torch.bfloat16
return "_fake_bf16"
return {
torch.float32: "_fp32",
torch.float16: "_fp16",
torch.bfloat16: "_bf16",
}[dtype]
def as_te_type(dtype: torch.dtype):
return {
torch.float32: tex.DType.kFloat32,
torch.float16: tex.DType.kFloat16,
torch.bfloat16: tex.DType.kBFloat16,
}[dtype]
def get_attn_mask_str(use_mask, attn_mask_type):
# See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
if attn_mask_type is None:
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = (
"_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
)
return attn_mask_str
"""
Test cases begin here.
"""
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
],
)
def test_export_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
precision: torch.dtype,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
out_features = 64
hidden_size = 64
class Test_Linear(nn.Module):
def __init__(self, in_features, out_features, use_bias, return_bias, precision):
super().__init__()
self.linear = te.Linear(
in_features,
out_features,
bias=use_bias,
return_bias=return_bias,
params_dtype=precision,
)
def forward(self, inp):
ret = self.linear(inp)
return ret
inp = torch.randn(batch_size, hidden_size, in_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
device="cuda"
)
# dynamic shape
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp,
fname,
fp8_recipe,
dynamic_shapes={"inp": {0: bs}},
)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
else:
validate_result(
fname, inp, model, atol=1e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize(
"precision",
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
out_features = 256
hidden_size = 256
inp = torch.ones(batch_size, in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
model = layernorm_cls(
hidden_size,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# dynamic shape
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"input": {0: bs}})
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
elif precision != torch.bfloat16:
validate_result(
fname,
inp,
model,
atol=1e-3,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = te.LayerNormLinear(
hidden_size,
3 * hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
).to(device="cuda")
if fp8_recipe is not None:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
elif precision != torch.bfloat16:
validate_result(
fname,
inp,
model,
atol=1e-3,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
activation: str,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
ffn_hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if fp8_recipe is not None else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
model = te.LayerNormMLP(
hidden_size,
ffn_hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
).to(device="cuda")
if fp8_recipe is not None:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, fp8_recipe)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16,):
return
atol = (
2e-2 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 1e-3)
) # TODO(pgadzinski) - check 2e-2
validate_result(
fname, inp, model, atol=atol, is_fp8=fp8_recipe is not None, te_outputs=te_outputs
)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
],
)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
):
# Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
qkv_format = "sbhd"
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask)
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16,):
return
validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
)
test_configs_multihead_attention = [
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
# "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
(False, "self", True),
(True, "self", False),
(False, "self", False),
(True, "cross", True),
(False, "cross", True),
(True, "cross", False),
(False, "cross", False),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
hidden_size = 256
sequence_length = 128
batch_size = 4
num_attention_heads = 32
kv_channels = 8
attention_dropout = 0.1
layernorm_epsilon = 1e-5
init_method = output_layer_init_method = get_default_init_method()
attention_args = (
hidden_size,
num_attention_heads,
kv_channels,
attention_dropout,
layernorm_epsilon,
init_method,
output_layer_init_method,
)
hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None
if attention_type == "cross":
encoder_output = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
fp8_str = "_fp8" if fp8_recipe is not None else ""
dtype_str = dtype2str(precision)
attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
input_ln_str = "_input-ln" if input_layernorm else ""
fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
model = te.MultiheadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
return_bias=True,
).to(device="cuda")
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names = ["attention_output", "attention_bias"]
seq = torch.export.Dim("seq", min=2, max=1256)
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp_context,
fname,
fp8_recipe,
input_names=input_names,
output_names=output_names,
dynamic_shapes={
"hidden_states": {0: seq, 1: bs},
"attention_mask": {2: seq, 0: bs} if use_mask else None,
"encoder_output": {0: seq, 1: bs} if attention_type == "cross" else None,
},
)
te_outputs = te_infer(model, inp_context, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
)
if precision in (torch.bfloat16,):
return
if fp8_recipe is None:
validate_result(
fname,
inp_context,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
te_outputs=te_outputs,
)
else:
validate_result(
fname,
inp_context,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
te_outputs=te_outputs,
)
# In GPT generative phase (inference) the input sequence is smaller than the maximum
# allowed sequence length and we want to test this condition.
# Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
if is_generative_phase:
seq_len_offset = 8
hidden_states_generative = torch.randn(
sequence_length - seq_len_offset,
batch_size,
hidden_size,
dtype=precision,
device="cuda",
)
inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if fp8_recipe is None:
validate_result(
fname,
inp_generative,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
)
else:
validate_result(
fname,
inp_generative,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 1
ffn_hidden_size = 256
num_attention_heads = 4
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask)
fp8_str = "_fp8" if fp8_recipe is not None else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
).to(device="cuda")
do_export(model, inp, fname, fp8_recipe, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname,
inp,
te_outputs,
input_names=input_names,
)
if precision in (torch.bfloat16,):
return
atol = 5e-1 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 5e-3)
validate_result(
fname,
inp,
model,
atol=atol,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 4
ffn_hidden_size = 256
num_attention_heads = 4
attention_mask = None
use_mask = True
attn_mask_type = "causal"
fuse_qkv_params = True
output_layernorm = False
fp8_str = "_fp8" if fp8_recipe is not None else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor,)
# dynamic shape
seq = torch.export.Dim("seq", min=2, max=1256)
bs = torch.export.Dim("bs", min=2, max=1256)
do_export(
model,
inp,
fname,
fp8_recipe,
dynamic_shapes={"hidden_states": {0: seq, 1: bs}},
)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(
fname, inp, te_outputs, input_names=input_names, output_names=output_names
)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8 and for MXFP8 we need to pad to mult of 32.
sequence_length = 1 if fp8_recipe is None else 32
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor, attention_mask)
te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=1e-2,
is_fp8=fp8_recipe is not None,
input_names=input_names,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
assert is_in_onnx_export_mode() == False
with te.onnx_export(enabled):
assert is_in_onnx_export_mode() == enabled
assert is_in_onnx_export_mode() == False
...@@ -61,22 +61,26 @@ class TestParallelCrossEntropy:
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
# Handle backward pass based on the test scenario
if reduce_loss:
test_loss.backward()
ref_loss.backward()
else:
test_loss.sum().backward()
ref_loss.sum().backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx:
print(test_loss, ref_loss)
# Compare gradients when backward pass was called
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
)
self.input_test = None
self.input_ref = None
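The unreduced branch sums the per-token losses before calling backward because autograd can only create an implicit gradient for scalar outputs. A tiny illustration:

import torch

x = torch.randn(4, requires_grad=True)
per_token_loss = 2 * x  # non-scalar, like the unreduced cross-entropy loss
# per_token_loss.backward() would raise: grad can be implicitly created only for scalar outputs
per_token_loss.sum().backward()
print(x.grad)  # gradient of the summed loss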
......
...@@ -326,33 +326,37 @@ def _test_permutation_index_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
...@@ -538,34 +542,38 @@ def _test_permutation_mask_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
...@@ -827,18 +835,19 @@ def _test_moe_chunk_sort(
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_output.float(),
te_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_fwd_input.grad.float(),
te_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
if not pytorch_fwd_input.numel():
print("Empty pytorch_fwd_input activation test passed.")
...@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
topK,
num_out_tokens,
tp_size,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
...@@ -1016,21 +1026,73 @@ def _test_permutation_mask_map_alongside_probs(
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in fused_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in fused_permute bwd",
**tols,
)
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
)
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: te_permute_with_probs(
te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
)
)
print(f"permute\t\tfwd: TE: {t1:.3f} ms")
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
te_probs,
routing_map,
num_out_tokens=num_out_tokens,
)
te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: TE: {t2:.3f} ms")
chunk_sort_fwd_input = te_permute_output.detach()
chunk_sort_fwd_input.requires_grad_(True)
chunk_sort_fwd_probs = te_permuted_probs.detach()
chunk_sort_fwd_probs.requires_grad_(True)
t1 = perf_test_cuda_kernel(
lambda: te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
)
print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")
chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
chunk_sort_output,
te_permute_bwd_input,
forward_input=[chunk_sort_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn): def perf_test_cuda_kernel(cuda_kernel_fn):
...@@ -1063,7 +1125,7 @@ if is_bf16_compatible(): ...@@ -1063,7 +1125,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1092,7 +1154,7 @@ def test_permutation_index_map( ...@@ -1092,7 +1154,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype): ...@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1193,7 +1255,7 @@ fp8_recipes = [ ...@@ -1193,7 +1255,7 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048]) @pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8( ...@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs( def test_permutation_index_map_topk1_no_probs(
te_dtype, te_dtype,
...@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs( ...@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs( def test_permutation_mask_map_topk1_no_probs(
te_dtype, te_dtype,
...@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs( ...@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation( def test_chunk_permutation(
...@@ -1372,5 +1434,108 @@ def test_permutation_single_case(): ...@@ -1372,5 +1434,108 @@ def test_permutation_single_case():
) )
def benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
torch.cuda.nvtx.range_push(
f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
)
torch.cuda.nvtx.range_push("permutation_index_map_with_probs")
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=True,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_with_probs")
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=True,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_without_probs")
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=False,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=tp_size,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases():
print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16
te_dtype = tex.DType.kBFloat16
ep_size = 64
tp_size = 2
num_tokens = 4096
num_expert = 256
hidden_size = 7168
topK = 8
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
ep_size = 8
tp_size = 1
num_tokens = 8192 * 2
num_expert = 128
hidden_size = 4096
topK = 6
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
ep_size = 64
tp_size = 2
num_tokens = 16384
num_expert = 4
hidden_size = 7168
topK = 1
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
if __name__ == "__main__": if __name__ == "__main__":
test_permutation_single_case() benchmark_multiple_cases()
...@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols from utils import dtype_tols
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( ...@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
) )
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should work the same when debug=True.
    # We feed them a dummy feature to prevent debug from being
    # switched off, which can happen if no feature is active.
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
def create_meta(scale_factor: float, size: int = 1): def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta() meta = tex.FP8TensorMeta()
...@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: ...@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
return torch.min(amax_history, dim=0).values return torch.min(amax_history, dim=0).values
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
global _cpu_rng_state, _cuda_rng_state
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""Transformer model configuration""" """Transformer model configuration"""
...@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba ...@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len num_tokens = bs * config.seq_len
...@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
def test_sanity_grouped_linear( def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
): ):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
...@@ -682,6 +715,8 @@ def test_sanity_gpt( ...@@ -682,6 +715,8 @@ def test_sanity_gpt(
parallel_attention_mlp, parallel_attention_mlp,
cpu_offload, cpu_offload,
): ):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -1367,6 +1402,8 @@ def test_inference_mode( ...@@ -1367,6 +1402,8 @@ def test_inference_mode(
quantization: Optional[str], quantization: Optional[str],
) -> None: ) -> None:
"""Test heuristics for initializing quantized weights""" """Test heuristics for initializing quantized weights"""
if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
pytest.skip("Quantized model parameters are not supported in debug mode.")
# Tensor dimensions # Tensor dimensions
sequence_length = 32 sequence_length = 32
......
...@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ...@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name in ("fp8", "fp8_delayed_scaling"): if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling( return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3, fp8_format=transformer_engine.common.recipe.Format.E4M3,
amax_history_len=8,
) )
if name == "fp8_current_scaling": if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling( return transformer_engine.common.recipe.Float8CurrentScaling(
......
...@@ -158,6 +158,9 @@ if(USE_CUDA) ...@@ -158,6 +158,9 @@ if(USE_CUDA)
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu recipe/current_scaling.cu
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu recipe/fp8_block_scaling.cu
...@@ -211,6 +214,9 @@ else() ...@@ -211,6 +214,9 @@ else()
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu recipe/current_scaling.cu
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu recipe/fp8_block_scaling.cu
......
...@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) { ...@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) {
} }
return false; return false;
#else #else
// Check run-time CUDA version
if (transformer_engine::cuda::cudart_version() < 12040) {
if (getenv("NVTE_UBDEBUG")) {
printf(
"TransformerEngine does not support multi-node NVLINK "
"since it is not being run with CUDA version >= 12.4.\n");
}
return false;
}
bool mnnvl_fabric_support = false; bool mnnvl_fabric_support = false;
CUdevice dev; CUdevice dev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
......
...@@ -248,7 +248,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -248,7 +248,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100))) && cudnn_runtime_version >= 91100)) &&
// 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 &&
head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
head_dim_qk != head_dim_v))) &&
// bias type // bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 && (cudnn_runtime_version >= 8906 &&
......
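The hunk above adds one more exclusion to the fused-attention backend selection: cuDNN 9.11 backward on Hopper with MLA-style head dimensions. Restated as a standalone predicate purely for readability (names are illustrative; the actual check lives inline in the condition above):

def hits_cudnn_9_11_hopper_mla_bug(cudnn_runtime_version, sm_arch, is_training,
                                   head_dim_qk, head_dim_v):
    # Mirrors the added condition: exclude cuDNN 9.11 bprop on sm90 when both head
    # dims are >= 128 and unequal, except the already-handled d_qk=192 / d_v=128 case.
    return (cudnn_runtime_version == 91100 and is_training and sm_arch == 90
            and head_dim_qk >= 128 and head_dim_v >= 128
            and not (head_dim_qk == 192 and head_dim_v == 128)
            and head_dim_qk != head_dim_v)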
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
namespace transformer_engine { namespace transformer_engine {
namespace kv_cache { namespace kv_cache {
constexpr int block_size = 1024;
template <typename dtype> template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices, __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
...@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat ...@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat
actual_b = i + 1; actual_b = i + 1;
} }
} }
bool flag = (batch_indices[0] != 0);
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; if (flag || ((batch_indices[batch_idx] - batch_indices[0]) != batch_idx)) {
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int num_tokens = (cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]) -
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) { (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]);
int num_elts_k = h_kv * d_k; int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v; int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; int num_elts = max(num_elts_k, num_elts_v);
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; for (int token_idx = blockIdx.x; token_idx < num_tokens; token_idx += gridDim.x) {
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; int src_offset = batch_indices[batch_idx] * max_seq_len + token_idx;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; int des_offset = batch_idx * max_seq_len + token_idx;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { dtype *k_cache_src_offset = k_cache + src_offset * num_elts_k;
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); dtype *k_cache_des_offset = k_cache + des_offset * num_elts_k;
} dtype *v_cache_src_offset = v_cache + src_offset * num_elts_v;
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { dtype *v_cache_des_offset = v_cache + des_offset * num_elts_v;
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
if (i < num_elts_k) {
*(k_cache_des_offset + i) = *(k_cache_src_offset + i);
}
if (i < num_elts_v) {
*(v_cache_des_offset + i) = *(v_cache_src_offset + i);
}
}
} }
} }
} }
...@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (batch_idx * max_ctx_len + i) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (batch_idx * max_ctx_len + i) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = dtype *token_id_k =
*(new_k + (new_token_offset + i) * h_kv * d_k + j); k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
} dtype *token_id_v =
for (int j = 0; j < h_kv * d_v; j++) { v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
*(v_cache + token_idx * h_kv * d_v + j) = for (int j = threadIdx.x; j < hd; j += blockDim.x) {
*(new_v + (new_token_offset + i) * h_kv * d_v + j); if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
} }
} }
} }
...@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (i * b + batch_idx) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (i * b + batch_idx) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); dtype *token_id_k =
} k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
for (int j = 0; j < h_kv * d_v; j++) { dtype *token_id_v =
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
} }
} }
} }
...@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (cu_new_lens[batch_idx] + i) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (cu_new_lens[batch_idx] + i) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = dtype *token_id_k =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
} dtype *token_id_v =
for (int j = 0; j < h_kv * d_v; j++) { v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
*(v_cache + token_idx * h_kv * d_v + j) = for (int j = threadIdx.x; j < hd; j += blockDim.x) {
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j); if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
}
} }
} }
} }
...@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso ...@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
bool is_non_paged, cudaStream_t stream) { bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) { if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) { if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>( reindex_kv_cache_kernel<<<max_seq_len, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr), reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr), reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
} }
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>( dim3 grid_size(b, max_ctx_len);
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr), reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr), reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr), reinterpret_cast<int *>(page_table.data.dptr),
......
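The kv_cache hunks above restructure the copy kernels: the launch grid becomes (b, max_ctx_len), each grid cell handles one (sequence, new token) pair, and a single thread loop of length h_kv * max(d_k, d_v) copies both K and V. The destination index of a new token in the (possibly paged) cache follows the arithmetic below; this is a plain-Python restatement for readability, with hypothetical names.

def paged_kv_dst_token_index(page_list, page_size, cached_len, new_len, i,
                             batch_idx, is_non_paged):
    # Absolute position of the i-th new token within its sequence.
    logical_pos = cached_len - new_len + i
    # Non-paged caches index pages directly by batch; paged caches go through the page table.
    page_idx = batch_idx if is_non_paged else page_list[logical_pos // page_size]
    return page_idx * page_size + logical_pos % page_size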
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "common/util/cuda_runtime.h"
#include "utils.h"
namespace transformer_engine {
// Use double to handle all the intermediate calculations
using CompType = double;
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf) {
#if __CUDA_ARCH__ >= 900
// Using cooperative_groups to manage the cluster
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
int thread_id = cg::this_grid().thread_rank();
int lane_id = thread_id % kThreadsPerWarp;
int warp_id = thread_id / kThreadsPerWarp;
int warp_num = blockDim.x * gridDim.x / kThreadsPerWarp;
  // Rank of this block within its cluster and number of blocks per cluster
int block_id = cluster.block_rank();
int block_num = cluster.dim_blocks().x;
int cluster_id = blockIdx.x / block_num;
if (cluster_id > 0) return; // Only use the cluster 0
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
/**
* Section: Reduce the probs to the aggregated_probs_per_expert
* 1. reduce on the block
* 2. reduce on the cluster
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
cluster.sync();
// The block 0 will reduce the results of all blocks
if (block_id == 0) {
for (int i = 1; i < block_num; i++) {
// Map the shared memory of the block i to the current block
CompType* dst_smem = reinterpret_cast<CompType*>(cluster.map_shared_rank(shmem_aux_loss, i));
for (int j = threadIdx.x; j < num_cols; j += blockDim.x) {
atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
}
}
}
cluster.sync();
/**
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
if (block_id == 0) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
if (warp_id == 0) {
/**
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
/**
* Section: Compute the aux_loss
*/
float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
Const_buf[0] = C_coeff;
}
}
}
#else
  // Use only one block of 1024 threads to avoid needing a grid-wide sync
if (blockIdx.x > 0) return;
int warp_num = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
/**
* Section: Reduce the probs to the aggregated_probs_per_expert
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
__syncthreads();
/**
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
if (warp_id == 0) {
/**
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
/**
* Section: Compute the aux_loss
*/
float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
Const_buf[0] = C_coeff;
}
}
#endif
}
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf,
cudaStream_t stream) {
if (cuda::sm_arch(cuda::current_device()) >= 90) {
cudaLaunchConfig_t config = {0};
int cluster_size = 8;
config.gridDim = cluster_size;
config.blockDim = 1024;
config.dynamicSmemBytes = sizeof(CompType) * num_cols;
config.stream = stream;
// Update the max cluster size based on the device
cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
attribute[0].val.clusterDim.x = cluster_size;
attribute[0].val.clusterDim.y = 1;
attribute[0].val.clusterDim.z = 1;
config.numAttrs = 1;
config.attrs = attribute;
cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf);
} else {
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
}
}
void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
int total_num_tokens, int num_experts, int num_rows, int num_cols,
int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
probs.data.dtype, DataType,
TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_forward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<DataType*>(probs.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
num_experts, num_rows, num_cols, topk, coeff,
reinterpret_cast<DataType*>(aux_loss.data.dptr),
reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
}
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs) {
int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
// Loop: for all positions in each row
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i];
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows
for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
}
}
}
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs, cudaStream_t stream) {
// Meta data for the kernel
int block_size = 256;
int grid_size = (num_rows + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
}
void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
int num_rows, int num_cols, Tensor& grad_aux_loss,
Tensor& grad_probs, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_aux_loss.data.dtype, DataType,
TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_backward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<float*>(Const_buf.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_rows, num_cols,
reinterpret_cast<DataType*>(grad_aux_loss.data.dptr),
reinterpret_cast<DataType*>(grad_probs.data.dptr), stream);););
}
} // namespace transformer_engine
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
using namespace transformer_engine;
fused_moe_aux_loss_forward(
*convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
*convertNVTETensorCheck(Const_buf), stream);
}
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_rows,
int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
using namespace transformer_engine;
fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
*convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols,
*convertNVTETensorCheck(grad_aux_loss),
*convertNVTETensorCheck(grad_probs), stream);
}
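The kernels above reduce probs over tokens, weight the per-expert sums by tokens_per_expert, and scale by C_coeff = num_experts * coeff / (topk * total_num_tokens^2); the backward pass broadcasts C_coeff * tokens_per_expert * grad_aux_loss to every row. A small PyTorch reference of that arithmetic, kept only as a sketch (function names are illustrative, not part of the library API):

import torch

def moe_aux_loss_fwd_reference(probs, tokens_per_expert, total_num_tokens,
                               num_experts, topk, coeff):
    # probs: [num_rows, num_experts]; tokens_per_expert: [num_experts]
    c_coeff = num_experts * coeff / (topk * total_num_tokens * total_num_tokens)
    aggregated = probs.double().sum(dim=0)                     # per-expert sum over tokens
    aux_loss = (aggregated * tokens_per_expert.double()).sum() * c_coeff
    return aux_loss.to(probs.dtype), c_coeff

def moe_aux_loss_bwd_reference(c_coeff, tokens_per_expert, grad_aux_loss, num_rows):
    # Every row receives the same per-expert gradient.
    grad_row = c_coeff * tokens_per_expert.double() * grad_aux_loss.double()
    return grad_row.unsqueeze(0).expand(num_rows, -1).to(grad_aux_loss.dtype)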
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace transformer_engine {
template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens,
int num_experts, int topk,
int score_function, DataType *scores,
bool *routing_map,
DataType *intermediate_output) {
/***
* Section: Global Variables/Addresses init
   * - Assume sizeof(DataType) >= sizeof(int),
   *   so the DataType buffers are laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads().
*/
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem_scores_for_aux_loss[];
DataType *logits_buf = reinterpret_cast<DataType *>(shmem_scores_for_aux_loss);
DataType *topk_logits_buf =
reinterpret_cast<DataType *>(logits_buf + num_experts * num_token_per_block);
int *topk_indices_buf = reinterpret_cast<int *>(topk_logits_buf + topk * num_token_per_block);
// The address of buffers on the current warp
DataType *local_logits = logits_buf + warp_id * num_experts;
DataType *topk_logits = topk_logits_buf + warp_id * topk;
int *topk_indices = topk_indices_buf + warp_id * topk;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the logits to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the routing_map (num_experts)
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
routing_map[pos_offset + i] = false;
if (score_function == 1) {
intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
}
}
// Load the logits to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = logits[pos_offset + i];
}
__threadfence_block();
__syncwarp();
/***
* Section: Preprocess
     * Optionally preprocess the scores before the topk operation:
     * - Pre-softmax
     * - Sigmoid
     * - Sigmoid post-processing when topk > 1
     * The scores are updated in place.
*/
// score_function == 1 means softmax
if (score_function == 1) {
// Apply softmax to the logits before the topk
apply_softmax_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
// Apply sigmoid to the logits
apply_sigmoid_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the sigmoid output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
}
    __syncwarp();  // Ensure the scores have been written to the softmax/sigmoid output
if (score_function == 0) {
if (topk > 1) {
auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
(static_cast<double>(sum_logits) + epsilon));
}
}
__syncwarp();
}
/***
* Section: Topk
* Get the topk indices
*/
naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id);
__syncwarp();
// Write the routing_map to the output tensor
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
routing_map[pos_offset + topk_indices[i]] = true;
}
// Write the scores to the output tensor
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[pos_offset + i] = local_logits[i];
}
__threadfence_block();
__syncwarp();
}
}
template <typename DataType>
void fused_score_for_moe_aux_loss_forward_kernel_launcher(
const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
DataType *scores, bool *routing_map, DataType *intermediate_output, cudaStream_t stream) {
// Meta data for the kernel
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits
+ topk * num_token_per_block * sizeof(DataType) // topk_logits
+ topk * num_token_per_block * sizeof(int); // topk_indices
fused_score_for_moe_aux_loss_forward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
}
void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
int topk, int score_function, Tensor &scores,
Tensor &routing_map, Tensor &intermediate_output,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
logits.data.dtype, DataType,
fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType>(
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
score_function, reinterpret_cast<DataType *>(scores.data.dptr),
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream););
}
template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output,
const DataType *grad_scores,
int num_tokens, int num_experts,
int topk, int score_function,
DataType *grad_logits) {
/***
* Section: Global Variables/Addresses init
   * - Assume sizeof(DataType) >= sizeof(int).
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads().
*/
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *grad_scores_buf = reinterpret_cast<DataType *>(shmem);
// To store the output of softmax/sigmoid from the fwd
DataType *act_from_fwd_buf =
reinterpret_cast<DataType *>(grad_scores_buf + num_experts * num_token_per_block);
DataType *comp_buf =
reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
// The address of buffers on the current warp
DataType *local_grad = grad_scores_buf + warp_id * num_experts;
DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
DataType *local_comp_buf = comp_buf + warp_id * num_experts;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the dgrad/output_from_fwd to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the logits_grad in global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = 0.0f;
}
// Load the dgrad/output_from_fwd to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] = grad_scores[pos_offset + i];
local_act_from_fwd[i] = intermediate_output[pos_offset + i];
}
__threadfence_block();
__syncwarp();
/***
* Section: Backward of ops before the topk
* - Pre-softmax bwd
* - Sigmoid Post-processing bwd when topk > 1
* - Sigmoid bwd
* - Write the grad_logits to the global mem
*/
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
}
__syncwarp();
auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] =
static_cast<double>(local_grad[i]) / (static_cast<double>(sum_fwd_input) + epsilon) -
static_cast<double>(sum_Output_x_Grad) /
((static_cast<double>(sum_fwd_input) + epsilon) *
(static_cast<double>(sum_fwd_input) + epsilon));
}
}
__syncwarp();
// Pre-softmax bwd
if (score_function == 1) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
num_experts, lane_id);
__syncwarp();
}
// Sigmoid bwd
if (score_function == 0) {
apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
__syncwarp();
}
// Write the grad_logits to the global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = local_grad[i];
}
__syncwarp();
}
}
template <typename DataType>
void fused_score_for_moe_aux_loss_backward_kernel_launcher(
const DataType *intermediate_output, const DataType *grad_scores, int num_tokens,
int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) {
// Meta data for the kernel
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_scores
+
num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(DataType); // comp_buf
fused_score_for_moe_aux_loss_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
grad_logits);
}
void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
const Tensor &grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
Tensor &grad_logits, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_scores.data.dtype, DataType,
fused_score_for_moe_aux_loss_backward_kernel_launcher<DataType>(
reinterpret_cast<DataType *>(intermediate_output.data.dptr),
reinterpret_cast<DataType *>(grad_scores.data.dptr), num_tokens, num_experts, topk,
score_function, reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}
} // namespace transformer_engine
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor scores, const NVTETensor routing_map,
const NVTETensor intermediate_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward);
using namespace transformer_engine;
fused_score_for_moe_aux_loss_forward(*convertNVTETensorCheck(logits), num_tokens, num_experts,
topk, score_function, *convertNVTETensorCheck(scores),
*convertNVTETensorCheck(routing_map),
*convertNVTETensorCheck(intermediate_output), stream);
}
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
const NVTETensor grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor grad_logits, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_backward);
using namespace transformer_engine;
fused_score_for_moe_aux_loss_backward(
*convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_scores),
num_tokens, num_experts, topk, score_function, *convertNVTETensorCheck(grad_logits), stream);
}
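The forward kernel in this file applies softmax or sigmoid to the logits, renormalizes sigmoid scores when topk > 1, saves the pre-normalization activation for the backward pass, and builds the boolean routing map from a per-token topk. A compact PyTorch reference of the forward numerics, as a sketch only (names and the epsilon value are assumptions; the kernel takes its epsilon from utils.h, which is not shown here):

import torch

def score_for_aux_loss_fwd_reference(logits, topk, score_function, eps=1e-20):
    # logits: [num_tokens, num_experts]; score_function: 1 = softmax, 0 = sigmoid.
    if score_function == 1:
        scores = torch.softmax(logits.double(), dim=-1)
        intermediate = scores.clone()                  # saved for backward
    else:
        scores = torch.sigmoid(logits.double())
        intermediate = scores.clone()                  # saved for backward
        if topk > 1:
            scores = scores / (scores.sum(dim=-1, keepdim=True) + eps)
    _, topk_idx = scores.topk(topk, dim=-1)
    routing_map = torch.zeros_like(scores, dtype=torch.bool)
    routing_map.scatter_(-1, topk_idx, True)
    return scores.to(logits.dtype), routing_map, intermediate.to(logits.dtype)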
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace transformer_engine {
template <typename DataType, typename BiasType>
__global__ void fused_topk_with_score_function_forward_kernel(
const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const BiasType *expert_bias, DataType *probs, bool *routing_map,
DataType *intermediate_output) {
/***
* Section: Global Variables/Addresses init
   * - Assume sizeof(DataType) >= sizeof(int),
   *   so the DataType buffers are laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads().
*/
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *scores_buf = reinterpret_cast<DataType *>(shmem);
DataType *topk_scores_buf =
reinterpret_cast<DataType *>(scores_buf + num_experts * num_token_per_block);
DataType *group_scores_buf = nullptr, *masked_scores_buf = nullptr;
int *topk_indices_buf = nullptr;
if (group_topk > 0) {
masked_scores_buf = reinterpret_cast<DataType *>(topk_scores_buf + topk * num_token_per_block);
group_scores_buf =
reinterpret_cast<DataType *>(masked_scores_buf + num_experts * num_token_per_block);
topk_indices_buf = reinterpret_cast<int *>(group_scores_buf + num_groups * num_token_per_block);
} else {
topk_indices_buf = reinterpret_cast<int *>(topk_scores_buf + topk * num_token_per_block);
}
// The address of buffers on the current warp
DataType *scores = scores_buf + warp_id * num_experts;
DataType *topk_scores = topk_scores_buf + warp_id * topk;
DataType *masked_scores = masked_scores_buf + warp_id * num_experts;
DataType *group_scores = group_scores_buf + warp_id * num_groups;
int *topk_indices = topk_indices_buf + warp_id * topk;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the logits to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the probs/routing_map (num_experts)
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
probs[pos_offset + i] = 0.0f;
routing_map[pos_offset + i] = false;
if (score_function == 1) {
intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
}
}
// Load the logits to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[i] = logits[pos_offset + i];
}
// If group_topk > 0, init the masked_scores to -inf
if (group_topk > 0) {
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
masked_scores[i] = -std::numeric_limits<DataType>::infinity();
}
}
__threadfence_block();
__syncwarp();
/***
* Section: Preprocess
     * Optionally preprocess the scores before the topk operation:
     * - Pre-softmax
     * - Sigmoid
     * - Expert bias
     * The scores are updated in place.
*/
// score_function == 1 means softmax
if (use_pre_softmax && score_function == 1) {
// Apply softmax to the logits before the topk
apply_softmax_on_float(scores, num_experts, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = scores[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
// Apply sigmoid to the logits
apply_sigmoid_on_float(scores, num_experts, lane_id);
__syncwarp();
// Save the sigmoid output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = scores[i];
}
}
    __syncwarp();  // Ensure the scores have been written to the softmax/sigmoid output
// Expert bias is only used at the sigmoid case
if (expert_bias && score_function == 0) {
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[i] = static_cast<DataType>(static_cast<double>(scores[i]) +
static_cast<double>(expert_bias[i]));
}
}
__syncwarp();
/***
* Section: Topk
* Get the topk indices
* - group_topk
* - naive topk
* - topk with expert bias
*/
// Topk on the scores
    // A non-empty bias can only occur in the sigmoid case
if (group_topk > 0) {
int group_size = num_experts / num_groups;
      // Topk within each group (topk / group_topk entries per group)
for (int i = 0; i < num_groups; i++) {
naive_topk_and_mask(
/*scores ptr = */ scores + i * group_size,
/*data size = */ group_size,
/*topk = */ topk / group_topk,
/*topk indices ptr = */ topk_indices,
/*topk scores ptr = */ topk_scores,
/*lane id = */ lane_id);
__syncwarp();
// Compute the group score
if (lane_id == 0) {
DataType tmp = 0.0f;
for (int j = 0; j < topk / group_topk; j++) {
tmp = tmp + topk_scores[j];
}
group_scores[i] = tmp;
}
__syncwarp();
}
// select the topk groups
naive_topk_and_mask(
/*scores ptr = */ group_scores,
/*data size = */ num_groups,
/*topk = */ group_topk,
/*topk indices ptr = */ topk_indices,
/*topk scores ptr = */ topk_scores,
/*lane id = */ lane_id);
__syncwarp();
// Copy the unmasked scores to the buffer
for (int i = 0; i < group_topk; i++) {
int st = topk_indices[i] * group_size;
int ed = st + group_size;
for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) {
masked_scores[j] = scores[j];
}
}
__syncwarp();
naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id);
} else {
naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id);
}
__syncwarp();
/***
* Section: Postprocess
     * Optionally postprocess the scores after the topk operation:
* - Revert Expert bias
* - Softmax
* - Sigmoid post-processing when topk > 1
* - Write the result with scaling_factor
*/
// Revert Expert bias from the topk scores
if (expert_bias && score_function == 0) {
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] =
static_cast<double>(topk_scores[i]) - static_cast<double>(expert_bias[topk_indices[i]]);
}
}
__syncwarp();
// score_function == 1 means softmax
if (!use_pre_softmax && score_function == 1) {
// Apply softmax to the topk logits
apply_softmax_on_float(topk_scores, topk, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
if (topk > 1) {
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
}
}
__syncwarp();
}
// Write the probs/routing_map to the output tensor
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
routing_map[pos_offset + topk_indices[i]] = true;
probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast<double>(topk_scores[i]);
}
__threadfence_block();
__syncwarp();
}
}
template <typename DataType, typename BiasType>
void fused_topk_with_score_function_forward_kernel_launcher(
const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const BiasType *expert_bias, DataType *probs, bool *routing_map, DataType *intermediate_output,
cudaStream_t stream) {
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // scores
+ topk * num_token_per_block * sizeof(DataType) // topk_scores
+ topk * num_token_per_block * sizeof(int); // topk_indices
if (group_topk > 0) {
shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores
    shared_memory_size += num_experts * num_token_per_block * sizeof(DataType);  // masked_scores
}
fused_topk_with_score_function_forward_kernel<DataType, BiasType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
}
void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,
int topk, bool use_pre_softmax, int num_groups,
int group_topk, float scaling_factor,
int score_function, const Tensor expert_bias,
Tensor probs, Tensor routing_map,
Tensor intermediate_output, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
logits.data.dtype, DataType,
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
expert_bias.data.dtype, BiasType,
fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType>(
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
reinterpret_cast<BiasType *>(expert_bias.data.dptr),
reinterpret_cast<DataType *>(probs.data.dptr),
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream);););
}
template <typename DataType>
__global__ void fused_topk_with_score_function_backward_kernel(
// Inputs tensor
const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
// Other parameters
int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
int score_function,
// Output tensor
DataType *grad_logits) {
/***
* Section: Global Variables/Addresses init
   * - Assumes sizeof(DataType) >= sizeof(int).
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads()
*/
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *grad_probs_buf = reinterpret_cast<DataType *>(shmem);
// To store the output of softmax/sigmoid from the fwd
DataType *act_from_fwd_buf =
reinterpret_cast<DataType *>(grad_probs_buf + num_experts * num_token_per_block);
DataType *comp_buf =
reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
// To store the routing_map from the fwd
bool *routing_map_buf = reinterpret_cast<bool *>(comp_buf + num_experts * num_token_per_block);
// The address of buffers on the current warp
DataType *local_grad = grad_probs_buf + warp_id * num_experts;
DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
DataType *local_comp_buf = comp_buf + warp_id * num_experts;
bool *local_routing_map = routing_map_buf + warp_id * num_experts;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the dgrad/output_from_fwd to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the logits_grad in global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = 0.0f;
}
// Load the dgrad/output_from_fwd to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] = grad_probs[pos_offset + i];
local_act_from_fwd[i] = intermediate_output[pos_offset + i];
local_routing_map[i] = routing_map[pos_offset + i];
}
__threadfence_block();
__syncwarp();
/***
* Section: Backward of ops after the topk
* - Backward of the used scaling_factor
* - Sigmoid Post-processing bwd when topk > 1
* - Softmax bwd if use_pre_softmax is false
*/
// Backward of the used scaling_factor
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
local_grad[i] = static_cast<double>(local_grad[i]) * scaling_factor;
}
}
__syncwarp();
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
double sum_fwd_input = masked_warp_reduce_on_shmem(
/*data ptr = */ local_act_from_fwd,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
static_cast<double>(local_act_from_fwd[i])
: 0.0f);
}
__syncwarp();
double sum_Output_x_Grad = masked_warp_reduce_on_shmem(
/*data ptr = */ local_comp_buf,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
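      // Backward of the topk > 1 re-normalization y_i = s_i / (S + epsilon) over the
      // selected experts, where S = sum_j s_j and g is the (already scaled) incoming grad:
      //   dL/ds_i = g_i / (S + epsilon) - (sum_j g_j * s_j) / (S + epsilon)^2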
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
local_grad[i] =
static_cast<double>(local_grad[i]) / (sum_fwd_input + epsilon) -
sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon));
} else {
local_grad[i] = 0.0f;
}
}
}
__syncwarp();
// Softmax bwd if use_pre_softmax is false
if (!use_pre_softmax && score_function == 1) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map,
num_experts, lane_id);
__syncwarp();
}
/***
* Section: Backward of topk
* mask the unselected position in the grad
*/
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (!local_routing_map[i]) {
local_grad[i] = 0.0f;
}
}
__syncwarp();
/***
* Section: Backward of ops before the topk
* - Pre-softmax bwd
* - Sigmoid bwd
* - Write the grad_logits to the global mem
*/
// Pre-softmax bwd
if (score_function == 1 && use_pre_softmax) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
num_experts, lane_id);
__syncwarp();
}
// Sigmoid bwd
if (score_function == 0) {
apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
__syncwarp();
}
// Write the grad_logits to the global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = local_grad[i];
}
__syncwarp();
}
}
template <typename DataType>
void fused_topk_with_score_function_backward_kernel_launcher(
const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
int score_function, DataType *grad_logits, cudaStream_t stream) {
// Meta data for the kernel
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_probs
+
num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(DataType) // comp_buf
+ num_experts * num_token_per_block * sizeof(bool); // routing_map
fused_topk_with_score_function_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
use_pre_softmax, scaling_factor, score_function, grad_logits);
}
void fused_topk_with_score_function_backward(const Tensor &routing_map,
const Tensor &intermediate_output,
const Tensor &grad_probs, int num_tokens,
int num_experts, int topk, bool use_pre_softmax,
float scaling_factor, int score_function,
Tensor &grad_logits, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_logits.data.dtype, DataType,
fused_topk_with_score_function_backward_kernel_launcher<DataType>(
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr),
reinterpret_cast<DataType *>(grad_probs.data.dptr), num_tokens, num_experts, topk,
use_pre_softmax, scaling_factor, score_function,
reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}
} // namespace transformer_engine
void nvte_fused_topk_with_score_function_forward(
const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map,
NVTETensor intermediate_output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_topk_with_score_function_forward);
using namespace transformer_engine;
fused_topk_with_score_function_forward(
*convertNVTETensorCheck(logits), num_tokens, num_experts, topk,
static_cast<bool>(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function,
*convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs),
*convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), stream);
}
void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
const NVTETensor intermediate_output,
const NVTETensor grad_probs, int num_tokens,
int num_experts, int topk, int use_pre_softmax,
float scaling_factor, int score_function,
NVTETensor grad_logits, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_topk_with_score_function_backward);
using namespace transformer_engine;
fused_topk_with_score_function_backward(
*convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output),
*convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk,
static_cast<bool>(use_pre_softmax), scaling_factor, score_function,
*convertNVTETensorCheck(grad_logits), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
constexpr size_t kThreadsPerWarp = 32;
constexpr int kThreadsPerBlock =
    128;  // Use 4 warps per CTA; each warp is responsible for 1 token.
constexpr float epsilon = 1e-20;  // Guards against division by zero when normalizing scores
template <typename T>
__device__ inline T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
__device__ inline T sum(T a, T b) {
return a + b;
}
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
int lane_id) {
  // Some values are handled within the local thread:
  // thread 0 is responsible for elements 0, 32, 64, 96, ...
  // First reduce those values locally
volatile double val =
lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]);
}
// Warp shuffle between threads
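  // XOR butterfly: after the 5 exchanges below, every lane in the warp holds the
  // fully reduced value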
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
__syncwarp();
return T(val);
}
template <typename DataType>
__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) {
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(1.0f / (1.0f + exp(-static_cast<float>(scores[i]))));
}
}
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
T (*reduce_func)(T, T), int lane_id) {
  // Some values are handled within the local thread:
  // thread 0 is responsible for elements 0, 32, 64, 96, ...
  // First reduce those values locally
volatile double val = lane_id < data_size && mask[lane_id]
? static_cast<double>(data_ptr[lane_id])
: static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) {
val = reduce_func(val, data_ptr[i]);
}
}
// Warp shuffle between threads
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
__syncwarp();
return T(val);
}
template <typename DataType>
__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output,
int data_size, int lane_id) {
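  // Sigmoid backward: with y = sigmoid(x) saved from the forward pass,
  //   dL/dx = g * y * (1 - y)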
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
grad[i] = static_cast<double>(grad[i]) * static_cast<double>(fwd_output[i]) *
(1 - static_cast<double>(fwd_output[i]));
}
}
template <typename DataType>
__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output,
DataType *comp_buf, bool *mask, int data_size,
int lane_id) {
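  // Softmax backward: with y = softmax(x) saved from the forward pass and upstream grad g,
  //   dL/dx_i = y_i * (g_i - sum_j g_j * y_j)
  // The optional mask zeroes contributions from unselected experts.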
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
if (mask[i])
comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
else
comp_buf[i] = 0.0f;
} else {
comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
}
}
__syncwarp();
float sum_Output_x_Grad = warp_reduce_on_shmem(
/*data ptr = */ comp_buf,
/*data size = */ data_size,
/*reduce func = */ sum, lane_id);
// In-place update
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
if (mask[i])
grad[i] =
static_cast<float>(fwd_output[i]) * (static_cast<float>(grad[i]) - sum_Output_x_Grad);
else
grad[i] = 0.0f;
} else {
grad[i] =
static_cast<float>(fwd_output[i]) * (static_cast<float>(grad[i]) - sum_Output_x_Grad);
}
}
}
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
// 1. compute the max of value
float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
// 2. value -> exp_value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
}
__syncwarp();
// 3. compute the sum of exp_value
float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
// 4. update the softmax value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(scores[i]) / sum_val;
}
__syncwarp();
}
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
T *topk_scores, int lane_id) {
// Topk Times: Find the max value and its index
// Then mask it, and record the index in the topk_indices
// After looping topk times, the topk_indices will be the topk indices
for (int k = 0; k < topk; k++) {
// Find the max value and its index
volatile double val =
(lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
volatile int index = (lane_id < data_size) ? lane_id : 0;
    // Some values are handled within the local thread:
    // thread 0 is responsible for elements 0, 32, 64, 96, ...
    // First reduce the values locally
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
// Warp shuffle between threads
for (int s = 16; s > 0; s /= 2) {
volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}
if (lane_id == 0) {
topk_indices[k] = index;
topk_scores[k] = val;
      scores[index] =
          static_cast<double>(-1.0) - val;  // mark the selected expert by storing -1 - val
}
__syncwarp();
}
  // Restore the selected scores: applying -1 - x a second time recovers the original value
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
scores[topk_indices[i]] =
static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
}
}
// Currently TE only supports float32/bf16/fp16; float64 probs should be considered in the future
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt64: { \
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
} // namespace transformer_engine
#endif
...@@ -229,6 +229,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -229,6 +229,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
return ret; return ret;
} }
/* cuBLAS version number at run-time */
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace } // namespace
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
...@@ -357,10 +364,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -357,10 +364,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&fastAccuMode, sizeof(fastAccuMode))); &fastAccuMode, sizeof(fastAccuMode)));
// Scaling factors. // Scaling factors.
#if CUDA_VERSION >= 12080 #if CUBLAS_VERSION >= 120800
cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_a;
cublasLtMatmulMatrixScale_t scaling_mode_b; cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif #endif // CUBLAS_VERSION >= 120800
if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv; void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv; void *B_scale_inverse = param.B_scale_inv;
...@@ -370,10 +377,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -370,10 +377,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse))); &B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080 #if CUBLAS_VERSION >= 120800
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif // CUBLAS_VERSION >= 120800
} else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv); fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...@@ -386,17 +397,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -386,17 +397,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) { if (cublas_version() <= 120803) {
const int64_t dummy_a_vec_stride = 1; const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride))); sizeof(dummy_a_vec_stride)));
} }
#else
NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120800
} else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUDA_VERSION >= 12090 #if CUBLAS_VERSION >= 120900
NVTE_CHECK(cublas_version() >= 120900,
"FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
cublas_version());
float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv); float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv); float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...@@ -415,20 +433,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -415,20 +433,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
: CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else #else
NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
#endif // CUDA_VERSION >= 12090 CUBLAS_VERSION);
#endif // CUDA_VERSION >= 12080 #endif // CUBLAS_VERSION >= 120900
} else { } else {
NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
to_string(inputB->scaling_mode) + "."); to_string(inputB->scaling_mode) + ".");
} }
#if CUDA_VERSION >= 12080 #if CUBLAS_VERSION >= 120800
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( if (cublas_version() >= 120800) {
operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); &scaling_mode_a, sizeof(scaling_mode_a)));
#endif NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
&scaling_mode_b, sizeof(scaling_mode_b)));
}
#endif // CUBLAS_VERSION >= 120800
if (is_fp8_dtype(outputD->data.dtype)) { if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output // Accumulation mode not supported for FP8 output
C = nullptr; C = nullptr;
...@@ -436,13 +458,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -436,13 +458,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080 #if CUBLAS_VERSION >= 120800
// NOTE: In all current cases where FP8 output is supported, the input is if (cublas_version() >= 120800) {
// scaled identically to the output. // NOTE: In all current cases where FP8 output is supported, the input is
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, // scaled identically to the output.
CUBLASLT_MATMUL_DESC_D_SCALE_MODE, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
&scaling_mode_a, sizeof(scaling_mode_a))); CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
#endif &scaling_mode_a, sizeof(scaling_mode_a)));
}
#endif // CUBLAS_VERSION >= 120800
// For FP8 output, cuBLAS requires C_type to match bias_type and // For FP8 output, cuBLAS requires C_type to match bias_type and
// be FP16/BF16 // be FP16/BF16
const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF; const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
...@@ -510,9 +534,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -510,9 +534,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue))); &epilogue, sizeof(epilogue)));
if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000 CUBLAS_VERSION < 130000
if (counter != nullptr) { NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
cublas_version());
if (m_split == 0) m_split = 1; if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1; if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
...@@ -530,8 +569,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -530,8 +569,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter, operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
sizeof(counter))); sizeof(counter)));
} }
}
#endif #endif
}
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
...@@ -723,17 +762,27 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -723,17 +762,27 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
int n_split, bool gemm_producer, const NVTETensor counter, int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) { cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_atomic_gemm); NVTE_API_CALL(nvte_cublas_atomic_gemm);
using namespace transformer_engine;
#ifndef __HIP_PLATFORM_AMD__ #ifndef __HIP_PLATFORM_AMD__
int cudart_version; // Check CUDA and cuBLAS versions
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version)); #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_CHECK(cudart_version >= 12020 && cudart_version < 13000, NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
"Cuda version >=12.2 and <13.0 is required for atomic gemm."); CUDA_VERSION);
NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000, #endif
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm."); #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
CUBLAS_VERSION);
#endif
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
cublas_version());
#endif #endif
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B); const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D); Tensor *outputD = convertNVTETensor(D);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] num_groups Number of groups in grouped topk.
* \param[in] group_topk Grouped topk value.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
 * \param[in] expert_bias Expert bias. (Only used in the sigmoid case)
* \param[out] probs Output tensor for probabilities.
* \param[out] routing_map Output tensor for routing map.
* \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_forward(
const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map,
NVTETensor intermediate_output, cudaStream_t stream);
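// Illustrative call sketch (hypothetical tensor handles; tensor creation is elided).
// All per-token tensors are laid out as [num_tokens, num_experts]; expert_bias is
// [num_experts] and is only read when score_function == 0 (sigmoid):
//
//   nvte_fused_topk_with_score_function_forward(
//       logits, num_tokens, num_experts, topk, /*use_pre_softmax=*/0,
//       /*num_groups=*/0, /*group_topk=*/0, /*scaling_factor=*/1.0f,
//       /*score_function=*/1, expert_bias, probs, routing_map,
//       intermediate_output, stream);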
/*! \brief Backward pass for fused topk + softmax/sigmoid.
*
* \param[in] routing_map Routing map.
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_probs Gradient of probs.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
const NVTETensor intermediate_output,
const NVTETensor grad_probs, int num_tokens,
int num_experts, int topk, int use_pre_softmax,
float scaling_factor, int score_function,
NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] scores Output tensor for scores.
* \param[in] routing_map Routing map.
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor scores, const NVTETensor routing_map,
const NVTETensor intermediate_output,
cudaStream_t stream);
/*! \brief Backward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_scores Gradient of scores.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
const NVTETensor grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for auxiliary loss.
*
* \param[in] probs Probabilities from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] total_num_tokens Number of total tokens. Will be used in seq/global aux loss.
* \param[in] num_experts Number of experts.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] topk Topk value.
* \param[in] coeff Coefficient.
* \param[out] aux_loss Output GPU scalar for auxiliary loss.
* \param[out] Const_buf Output GPU scalar for temporary constant buffer for backward pass.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream);
/*! \brief Backward pass for auxiliary loss.
*
* \param[in] Const_buf Constant buffer from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] grad_aux_loss Gradient of auxiliary loss.
* \param[out] grad_probs Gradient of probs.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_rows,
int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif  // TRANSFORMER_ENGINE_FUSED_ROUTER_H_
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H #ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H #define TRANSFORMER_ENGINE_MULTI_STREAM_H
#include "cuda_runtime.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -18,6 +20,26 @@ extern "C" { ...@@ -18,6 +20,26 @@ extern "C" {
/*! \brief Number of CUDA streams to use in multi-stream operations */ /*! \brief Number of CUDA streams to use in multi-stream operations */
int nvte_get_num_compute_streams(); int nvte_get_num_compute_streams();
/*! \brief Get a CUDA stream for compute operations.
*
 * \param[in] idx Index of the stream to retrieve.
* \return A cudaStream_t.
*
* This function returns a CUDA stream that can be used for compute operations.
* The index should be in the range [0, nvte_get_num_compute_streams() - 1].
*/
cudaStream_t nvte_get_compute_stream(const int idx);
/*! \brief Get a CUDA event for compute operations.
*
* \param[in] idx Index of the event to retrieve.
* \return A cudaEvent_t.
*
* This function returns a CUDA event that can be used to synchronize compute operations.
* The index should be in the range [0, nvte_get_num_compute_streams() - 1].
*/
cudaEvent_t nvte_get_compute_stream_event(const int idx);
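// Hypothetical synchronization sketch: record an event on each compute stream and
// make a main stream wait for it (names other than the nvte_* calls are assumed):
//
//   for (int i = 0; i < nvte_get_num_compute_streams(); ++i) {
//     cudaEventRecord(nvte_get_compute_stream_event(i), nvte_get_compute_stream(i));
//     cudaStreamWaitEvent(main_stream, nvte_get_compute_stream_event(i), 0);
//   }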
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -44,6 +44,33 @@ extern "C" { ...@@ -44,6 +44,33 @@ extern "C" {
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* padded_num_rows_list, cudaStream_t stream); const int* padded_num_rows_list, cudaStream_t stream);
/*! \brief Unpadding multiple tensors (reverse operation of padding).
*
* NOTE: Unpadding mode only removes bottom rows.
*
* For example, 4x3 matrix unpad to 3x3 matrix.
*
* source
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
* | 0 | 0 | 0 |
*
* destination
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D padded input tensors.
* \param[in,out] output_list List of unpadded tensors. Dimensions
* match original unpadded tensors.
* \param[in] unpadded_num_rows_list List of unpadded num rows corresponding to input tensors.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* unpadded_num_rows_list, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -226,8 +226,8 @@ struct AdamFunctorMasterParamRemainder { ...@@ -226,8 +226,8 @@ struct AdamFunctorMasterParamRemainder {
r_m[ii] = static_cast<MATH_T>(m[i]); r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]); r_v[ii] = static_cast<MATH_T>(v[i]);
local_p[ii] = static_cast<int16_t>(p[i]); local_p[ii] = p[i];
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]); local_p_rem[ii] = p_remainder[i];
} else { } else {
r_g[ii] = MATH_T(0); r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0); r_m[ii] = MATH_T(0);
...@@ -281,8 +281,8 @@ struct AdamFunctorMasterParamRemainder { ...@@ -281,8 +281,8 @@ struct AdamFunctorMasterParamRemainder {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]); p_remainder[i] = local_p_rem[ii];
p[i] = static_cast<int16_t>(local_p[ii]); p[i] = local_p[ii];
m[i] = static_cast<FULL_T>(r_m[ii]); m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]); v[i] = static_cast<FULL_T>(r_v[ii]);
...@@ -467,8 +467,8 @@ struct AdamCapturableFunctor { ...@@ -467,8 +467,8 @@ struct AdamCapturableFunctor {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p[i] = static_cast<T>(r_p[ii]); p[i] = static_cast<T>(r_p[ii]);
m[i] = static_cast<T>(r_m[ii]); m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<T>(r_v[ii]); v[i] = static_cast<FULL_T>(r_v[ii]);
} }
} }
} }
...@@ -578,9 +578,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -578,9 +578,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) { const float weight_decay, const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -588,16 +585,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -588,16 +585,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
size_t max_size = 0; // Check tensor list sizes
// 4 tensor lists: g, p, m, v
// 5 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5,
"Expected 4 or 5 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(p_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
if (num_tensor_lists == 5) {
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
}
// Check if 64-bit indices are required
bool requires_64bit_indexing = false; bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) { for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) { if (tensor_lists[i][j]->numel() >= INT_MAX) {
max_size = tensor_lists[i][j]->numel(); requires_64bit_indexing = true;
if (max_size >= INT_MAX) { break;
requires_64bit_indexing = true;
break;
}
} }
} }
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
...@@ -605,16 +634,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -605,16 +634,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
} }
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Launch kernel
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
if (num_tensor_lists == 4) { if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now // g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -638,7 +661,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -638,7 +661,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
} }
} else { } else {
if (num_tensor_lists == 4) { if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now // g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -648,6 +671,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -648,6 +671,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
stream, beta1, beta2, bias_correction1, bias_correction2, stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);)); epsilon, lr, (adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -668,8 +692,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -668,8 +692,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float epsilon, const int step, const int mode, const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -677,23 +699,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -677,23 +699,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Check tensor list sizes
const auto p_in_type_te = tensor_lists[1][0]->dtype(); // 5 tensor lists: g, p, m, v, p_remainder
const size_t num_tensor_lists = tensor_lists.size();
// case 5: g, p, m, v, p_master NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5"); const size_t num_tensors_per_list = tensor_lists[0].size();
NVTE_CHECK(p_in_type_te == DType::kBFloat16, for (size_t i = 1; i < num_tensor_lists; i++) {
"Adam with BF16 param remainders requires BF16 params"); NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// g, p, m, v, p_master // Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(DType::kBFloat16));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kInt16));
}
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id, AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);); (adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -703,9 +745,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -703,9 +745,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype, const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -713,16 +752,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -713,16 +752,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
size_t max_size = 0; // Check tensor list sizes
// 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(
tensor_lists[1][j]->dtype() == fp8_dtype || tensor_lists[1][j]->dtype() == DType::kByte,
"Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(fp8_dtype));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j,
" has dtype=", to_string(tensor_lists[5][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j,
" has dtype=", to_string(tensor_lists[6][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j,
" has dtype=", to_string(tensor_lists[7][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
// Check if 64-bit indices are required
bool requires_64bit_indexing = false; bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) { for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) { if (tensor_lists[i][j]->numel() >= INT_MAX) {
max_size = tensor_lists[i][j]->numel(); requires_64bit_indexing = true;
if (max_size >= INT_MAX) { break;
requires_64bit_indexing = true;
break;
}
} }
} }
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
...@@ -730,11 +806,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -730,11 +806,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
} }
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Launch kernel
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T, fp8_dtype, FP8_T,
...@@ -765,6 +837,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, ...@@ -765,6 +837,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
Tensor step, const int mode, const int bias_correction, Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale, const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
...@@ -783,6 +883,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, ...@@ -783,6 +883,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id, Tensor inv_scale, const int device_id,
cudaStream_t stream) { cudaStream_t stream) {
// Check tensor list sizes
// 5 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
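For reference, the list layouts implied by the three sets of checks in this file (the ordering below is what the checks and their error messages imply, not separately documented) are summarized in the comment block that follows.
// List layouts implied by the checks in this file:
//   multi_tensor_adam_fp8_cuda               : 8 lists -> g, p_fp8, m, v, p_master, scale, amax, scale_inv
//   multi_tensor_adam_capturable_cuda        : 4 lists -> g, p, m, v
//   multi_tensor_adam_capturable_master_cuda : 5 lists -> g, p, m, v, p_master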
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
......
...@@ -52,7 +52,7 @@ class OptionalCUDAGuard { ...@@ -52,7 +52,7 @@ class OptionalCUDAGuard {
~OptionalCUDAGuard() { ~OptionalCUDAGuard() {
if (device_changed_) { if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_)); cudaSetDevice(prev_device_);
} }
} }
......
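For context, the destructor change above touches the usual RAII device-guard pattern: remember the current device, switch if needed, and restore on scope exit. The standalone sketch below (the class name ScopedDeviceGuard is illustrative, not the TransformerEngine type) keeps the destructor non-throwing and only reports a failed restore, in the spirit of calling cudaSetDevice directly there.
// Standalone illustration of an RAII CUDA device guard -- not the
// actual OptionalCUDAGuard implementation.
#include <cstdio>
#include <cuda_runtime.h>

class ScopedDeviceGuard {
 public:
  explicit ScopedDeviceGuard(int device) {
    if (cudaGetDevice(&prev_device_) != cudaSuccess) return;
    if (device >= 0 && device != prev_device_) {
      if (cudaSetDevice(device) == cudaSuccess) device_changed_ = true;
    }
  }

  ~ScopedDeviceGuard() {
    if (device_changed_) {
      // Destructors should not throw, so a failed restore is only reported.
      cudaError_t err = cudaSetDevice(prev_device_);
      if (err != cudaSuccess) {
        std::fprintf(stderr, "Failed to restore CUDA device %d: %s\n", prev_device_,
                     cudaGetErrorString(err));
      }
    }
  }

  ScopedDeviceGuard(const ScopedDeviceGuard&) = delete;
  ScopedDeviceGuard& operator=(const ScopedDeviceGuard&) = delete;

 private:
  int prev_device_ = 0;
  bool device_changed_ = false;
};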