Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
......@@ -17,6 +17,48 @@ from transformer_engine.pytorch import (
Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
nvfp4_ref_rht_2d_quantizer_factory,
)
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear"])
def test_custom_recipe_sanity_modules_nvfp4(module_type):
"""Test modules with NVFP4 custom recipe support"""
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
# Simple linear layer with dims divisible by 16
in_features = 64
out_features = 64
batch = 32
if module_type == "Linear":
model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda()
elif module_type == "LayerNormLinear":
model = LayerNormLinear(
in_features, out_features, params_dtype=torch.bfloat16, bias=False
).cuda()
else: # OpsLinear
model = te_ops.Linear(
in_features, out_features, device="cuda", dtype=torch.bfloat16, bias=False
)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Use NVFP4 quantizer factory
custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
# Execute with custom recipe
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
# Basic sanity: gradients exist
assert inp.grad is not None
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
......
......@@ -58,10 +58,6 @@ def test_fused_rope(
# are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
......@@ -102,11 +98,8 @@ def test_fused_rope(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
......@@ -121,17 +114,12 @@ def test_fused_rope(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
......@@ -156,10 +144,6 @@ def test_fused_rope_thd(
margin: int,
) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
......@@ -214,8 +198,6 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
......@@ -233,20 +215,144 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
"""
This is just a sanity check to ensure that the unfused RoPE produces the same
results for the THD/SBHD/BSHD formats.
"""
device = torch.device("cuda:0")
seqlen, max_seqlen = 16, 2048
batch_size, head_num = 4, 256
# NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and
# that causes unexpected issues.
seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)
cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to(
device=device, dtype=torch.int32
)
# Create a tensor in THD format
thd = torch.rand(
(cu_seqlens[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
thd.requires_grad = True
# Clone the tensor to create a tensor in BSHD format
bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
bshd = bshd.to(dtype=dtype, device=device)
bshd.requires_grad = True
# Clone the tensor to create a tensor in SBHD format
sbhd = bshd.transpose(1, 0).clone().detach()
sbhd = sbhd.to(dtype=dtype, device=device)
sbhd.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(max_seqlen)
assert emb.is_contiguous()
start_positions = cu_seqlens[:-1] if start_positions else None
for cp_rank in range(cp_size):
# unfused bshd
output_unfused_bshd = apply_rotary_pos_emb(
bshd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="bshd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_bshd = loss_func(output_unfused_bshd)
loss_unfused_bshd.backward()
grad_unfused_bshd = bshd.grad.detach().clone()
bshd.grad = None
# unfused sbhd
output_unfused_sbhd = apply_rotary_pos_emb(
sbhd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="sbhd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_sbhd = loss_func(output_unfused_sbhd)
loss_unfused_sbhd.backward()
grad_unfused_sbhd = sbhd.grad.detach().clone()
sbhd.grad = None
# unfused thd
output_unfused_thd = apply_rotary_pos_emb(
thd.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_thd = loss_func(output_unfused_thd)
loss_unfused_thd.backward()
grad_unfused_thd = thd.grad.detach().clone()
thd.grad = None
torch.testing.assert_close(
output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
)
torch.testing.assert_close(
output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
output_unfused_thd,
)
torch.testing.assert_close(
grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
torch.testing.assert_close(
grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
assert output_unfused_thd.is_contiguous()
assert output_unfused_bshd.is_contiguous()
assert output_unfused_sbhd.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
......
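For intuition, the layout equivalence that `test_unfused_rope_thd_vs_bshd` relies on can be sketched outside the test harness: with equal-length sequences, THD, BSHD and SBHD are just views/transposes of the same buffer. A minimal sketch, with illustrative shapes that are not tied to the test's configuration:

import torch

t, h, d = 64, 256, 128                    # total tokens, heads, head dim (illustrative)
b = 4                                     # batch size; t is assumed divisible by b
thd = torch.randn(t, h, d)
bshd = thd.view(b, t // b, h, d)          # [batch, seq, head, dim]
sbhd = bshd.transpose(0, 1).contiguous()  # [seq, batch, head, dim]
# Round-tripping back to THD recovers the original buffer.
assert torch.equal(sbhd.transpose(0, 1).reshape(t, h, d), thd)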
......@@ -41,21 +41,22 @@ from transformer_engine.pytorch import (
is_mxfp8_available,
is_fp8_block_scaling_available,
is_bf16_available,
is_nvfp4_available,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
from utils import ModelConfig, reset_rng_states
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available = is_fp8_block_scaling_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(return_reason=True)
nvfp4_available = is_nvfp4_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -120,6 +121,43 @@ if NVTE_TEST_NVINSPECT_ENABLED:
)
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
def check_rht_usage(recipe: recipe.Recipe) -> bool:
# if using RHT, we can only support bf16
# check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
if recipe.nvfp4():
if (
recipe.fp4_quant_fwd_inp.random_hadamard_transform
or recipe.fp4_quant_fwd_weight.random_hadamard_transform
or recipe.fp4_quant_bwd_grad.random_hadamard_transform
):
return True
return False
def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list[torch.dtype]:
supported_input_dtypes = []
if recipe.nvfp4():
supported_input_dtypes.append(torch.bfloat16)
# if not using RHT, we can add fp32 as well
if not check_rht_usage(recipe):
supported_input_dtypes.append(torch.float32)
return supported_input_dtypes
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
......@@ -128,6 +166,8 @@ if fp8_block_scaling_available:
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
if nvfp4_available:
fp8_recipes.append(nvfp4_rht_and_2d_quantization())
use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper
......@@ -135,23 +175,6 @@ if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True)
def is_fused_attn_available(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout="bshd_bshd_bshd",
is_training=True,
deterministic=False,
):
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=deterministic,
)
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
......@@ -612,6 +635,11 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model]
......@@ -729,6 +757,11 @@ def test_gpt_full_activation_recompute(
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model]
......@@ -872,8 +905,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, deterministic=True):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
......@@ -920,10 +951,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
......@@ -1035,10 +1062,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
config.hidden_size,
......@@ -1327,6 +1350,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear(
config.hidden_size,
......@@ -1770,8 +1799,8 @@ def _test_grouped_linear_accuracy(
split_size = 1
if fp8:
split_size = 16
if recipe.mxfp8():
split_size = 128
if recipe.mxfp8() or recipe.nvfp4():
split_size = 32
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
......@@ -1849,6 +1878,12 @@ def test_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
......@@ -1993,6 +2028,12 @@ def test_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
......@@ -2086,7 +2127,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = 16
if recipe.mxfp8():
if recipe.mxfp8() or recipe.nvfp4():
align_size = 32
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
......@@ -2207,6 +2248,12 @@ def test_padding_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
......@@ -2284,6 +2331,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
......@@ -2499,6 +2552,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model]
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
......
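The MXFP8/NVFP4 alignment changes above all come down to one rounding rule: per-expert token counts are padded up to a multiple of the recipe's alignment (16 for tensor-scaled FP8, 32 for MXFP8 and NVFP4 block scaling). A minimal sketch of the arithmetic used in `_pad_tensor_for_fp8`:

def round_up(num_tokens: int, align_size: int) -> int:
    # Same rounding expression as in _pad_tensor_for_fp8 above.
    return (num_tokens + align_size - 1) // align_size * align_size

assert round_up(100, 16) == 112  # tensor-scaled FP8 recipes
assert round_up(100, 32) == 128  # MXFP8 / NVFP4 block-scaling recipes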
......@@ -68,7 +68,7 @@ if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
......
......@@ -123,6 +123,7 @@ all_activations = [
"sreglu",
"silu",
"swiglu",
"clamped_swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]
......@@ -566,7 +567,7 @@ def test_sanity_layernorm_mlp(
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
block = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
......@@ -574,6 +575,7 @@ def test_sanity_layernorm_mlp(
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
activation_params=activation_params,
normalization=normalization,
params_dtype=dtype,
device="cuda",
......
......@@ -205,6 +205,7 @@ class ModelConfig:
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_logit=False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -233,6 +234,7 @@ class ModelConfig:
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_logit = return_max_logit
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
)
(
use_flash_attention,
......
......@@ -29,15 +29,6 @@ endif()
# Language options
if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
......@@ -54,8 +45,62 @@ if(USE_CUDA)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
if (CUDAToolkit_VERSION VERSION_LESS 12.1)
message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif()
# Process GPU architectures
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
set(NVTE_GENERIC_ARCHS)
set(NVTE_SPECIFIC_ARCHS)
# Check for architecture 100
list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
if(NOT arch_100_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
list(APPEND NVTE_GENERIC_ARCHS "100")
list(APPEND NVTE_SPECIFIC_ARCHS "100a")
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "103a")
endif()
endif()
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
if(NOT arch_101_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
list(APPEND NVTE_GENERIC_ARCHS "101")
list(APPEND NVTE_SPECIFIC_ARCHS "101a")
endif()
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
if(NOT arch_110_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
list(APPEND NVTE_GENERIC_ARCHS "110")
list(APPEND NVTE_SPECIFIC_ARCHS "110f")
endif()
# Check for architecture 120
list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
if(NOT arch_120_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
list(APPEND NVTE_GENERIC_ARCHS "120")
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "120f")
else()
list(APPEND NVTE_SPECIFIC_ARCHS "120a")
endif()
endif()
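# Net effect of the checks above: generic targets collected in NVTE_GENERIC_ARCHS
# are used for portable CUDA sources, while the matching arch-specific targets in
# NVTE_SPECIFIC_ARCHS are used for the arch-specific sources:
#   100 -> 100a (plus 103a with CUDA >= 12.9)
#   101 -> 101a (toolkits <= 12.9)
#   110 -> 110f (toolkits >= 13.0)
#   120 -> 120f with CUDA >= 12.9, otherwise 120a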
# cuDNN frontend API
......@@ -110,38 +155,32 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
if(USE_CUDA)
# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
endif()
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
if(NOT _MATHDX_LOC_MATCH)
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
endif()
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
set(transformer_engine_cpp_sources)
set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)
if(USE_CUDA)
list(APPEND transformer_engine_SOURCES
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
transformer_engine.cpp
fused_attn/fused_attn.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
......@@ -153,40 +192,23 @@ if(USE_CUDA)
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
......@@ -200,26 +222,91 @@ if(USE_CUDA)
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
# Set compile options for CUDA sources with generic architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
# Set compile options for CUDA sources with specific architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
# CUTLASS kernels require SM90a and cause hang in debug build
set_property(
SOURCE gemm/cutlass_grouped_gemm.cu
APPEND
PROPERTY
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
else()
list(APPEND transformer_engine_SOURCES
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
transformer_engine.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cuda_sources
common.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/utils.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
......@@ -230,31 +317,21 @@ else()
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
dropout/dropout.cu
activation/relu.cu
activation/swiglu.cu
gemm/config.cpp
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
......@@ -267,10 +344,22 @@ else()
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
......@@ -311,14 +400,16 @@ else()
add_library(transformer_engine SHARED ${te_hip_sources})
endif()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
if (USE_CUDA)
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
......@@ -439,7 +530,8 @@ target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
set(nvte_sources_with_fast_math)
list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
......@@ -449,20 +541,23 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
fused_attn/kv_cache.cu)
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
activation/swiglu.cu)
endif()
if(USE_CUDA)
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS "--use_fast_math")
endforeach()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()
......@@ -491,10 +586,10 @@ else()
endif()
# Number of parallel build jobs
if(ENV{MAX_JOBS})
set(BUILD_JOBS_STR "$ENV{MAX_JOBS}")
elseif(ENV{NVTE_BUILD_MAX_JOBS})
set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}")
if($ENV{MAX_JOBS})
set(BUILD_JOBS_STR $ENV{MAX_JOBS})
elseif($ENV{NVTE_BUILD_MAX_JOBS})
set(BUILD_JOBS_STR $ENV{NVTE_BUILD_MAX_JOBS})
else()
set(BUILD_JOBS_STR "max")
endif()
......
......@@ -8,22 +8,19 @@ import ctypes
import functools
import glob
import importlib
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
from importlib.metadata import version, distribution, PackageNotFoundError
import os
from pathlib import Path
import platform
import subprocess
import sys
import sysconfig
from typing import Optional
from typing import Optional, Tuple
from torch.utils.cpp_extension import IS_HIP_EXTENSION
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package) -> bool:
def _is_package_installed(package) -> bool:
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
......@@ -31,12 +28,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
distribution(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _is_package_installed_from_wheel(package) -> bool:
"""Check if the given package is installed via PyPI."""
if not _is_package_installed(package):
return False
te_dist = distribution(package)
te_wheel_file = ""
for file_path in te_dist.files:
if file_path.name == "WHEEL":
te_wheel_file = te_dist.locate_file("") / file_path
if not te_wheel_file:
return False
with te_wheel_file.open("r") as f:
for line in f:
if line.startswith("Root-Is-Purelib:"):
return line.strip().split(":")[1].strip().lower() == "true"
return False
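# Context for the parsing above: the "WHEEL" file in a package's .dist-info
# directory (PEP 427) holds key/value metadata, for example (illustrative
# contents only):
#
#     Wheel-Version: 1.0
#     Root-Is-Purelib: true
#     Tag: py3-none-any
#
# Root-Is-Purelib records whether the wheel installs into purelib; the helper
# treats its presence with value "true" as a wheel/PyPI-style install.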
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
"""
......@@ -112,6 +131,19 @@ def _get_shared_object_file(library: str) -> Path:
)
def get_te_core_package_info() -> Tuple[bool, str, str]:
"""
Check if a Transformer Engine core package is installed.
Returns whether it was found, along with the package name and version.
"""
te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
for package in te_core_packages:
if _is_package_installed(package):
return True, package, version(package)
return False, "", ""
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None:
"""
......@@ -130,37 +162,28 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch":
extra_dep_name = "pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
# For the `transformer-engine` package, we need to check explicitly.
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_framework_installed = _is_package_installed(module_name)
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
# extension are all installed via PyPI and have matching versions.
if te_framework_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
assert version(module_name) == version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and {te_core_package_name}"
f" v{te_core_version}. Install transformer-engine using "
f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
)
# After all checks are completed, load the shared object file.
......@@ -170,6 +193,35 @@ def load_framework_extension(framework: str) -> None:
spec.loader.exec_module(solib)
def sanity_checks_for_pypi_installation() -> None:
"""Ensure that package is installed correctly if using PyPI."""
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer-engine`."
# If the core package is installed via PyPI.
if te_core_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found "
f"transformer-engine v{version('transformer-engine')} "
f"and {te_core_package_name} v{te_core_version}."
)
# Only the metapackage is found; this is an invalid use case.
elif te_installed_via_pypi:
raise RuntimeError(
"Found empty `transformer-engine` meta package installed. "
"Install `transformer-engine` with framework extensions via"
"'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
" or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
" or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
)
@functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str:
"""File extension for shared objects."""
......@@ -253,9 +305,7 @@ def _load_cudnn():
if not IS_HIP_EXTENSION:
# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
)
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
......@@ -285,9 +335,7 @@ def _load_nvrtc():
return handle
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True
)
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
......@@ -317,9 +365,7 @@ def _load_curand():
return handle
# Attempt to locate cuRAND via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True
)
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
......@@ -340,15 +386,16 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
_TE_LIB_CTYPES = _load_core_library()
......@@ -14,26 +14,17 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
......@@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
......@@ -63,8 +51,7 @@ template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
......
......@@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
......@@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
......@@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
......@@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
......@@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
......@@ -12,36 +12,20 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "dispatch/dequantize.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/transpose.h"
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
......@@ -61,15 +45,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
......@@ -79,87 +56,17 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
......@@ -168,12 +75,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
const size_t num_streams = nvte_get_num_compute_streams();
......@@ -186,9 +88,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
}
for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file common.cuh
* \brief Common functions in quantize.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace common {
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % elems_per_block == 0);
return isFullTile;
}
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
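// Example (illustrative): TMA needs 16-byte alignment on the innermost
// dimension, so the required column multiple is (16 * 8) / bits_per_element:
//   FP8  (8-bit)  -> cols must be a multiple of 16
//   BF16 (16-bit) -> cols must be a multiple of 8
//   FP32 (32-bit) -> cols must be a multiple of 4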
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
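// Reduces the partial dbias buffer (rows x cols, fp32) to a single row: each
// thread owns nvec contiguous columns, accumulates them over all rows, and
// stores the result in the output dtype.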
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
}
const float *const thread_in_base = dbias_partial + thread_id * nvec;
OType *const thread_out_base = dbias_output + thread_id * nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < rows; ++i) {
ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base);
}
} // namespace kernel
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
using namespace kernel;
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace common
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize.cuh
* \brief Dequantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
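// Dispatches dequantization based on the input's scaling mode: tensor-scaled
// FP8/INT8, MXFP8 (SM 10.0+ only), or NVFP4. The output is expected to be a
// higher-precision tensor with the same shape as the input.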
inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype),
"Input must have FP8 or INT8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype),
"Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
fp8::dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
mxfp8::dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
nvfp4::dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated.cuh
* \brief Gated dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
namespace transformer_engine {
namespace dispatch {
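// Forward gated-activation quantization. The input packs the activation and
// gate halves along the last dimension ([*, 2 * cols]); the output holds
// ActOP(act) * gate ([*, cols]) quantized according to the output scaling mode
// (tensor-scaled FP8, via TMA on SM 10.0+ when possible, or MXFP8).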
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
cudaStream_t stream) {
const Tensor input = *convertNVTETensorCheck(nvte_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim() / 2;
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols,
"Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
Tensor dummy_grad_tensor;
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
} else {
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
Tensor dummy_grad_tensor;
mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
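// Backward gated-activation quantization. Takes the incoming gradient
// ([*, cols], higher precision) and the original gated input ([*, 2 * cols]),
// and produces the quantized gradients for both halves ([*, 2 * cols]).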
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
const Tensor &grad = *(convertNVTETensorCheck(nvte_grad));
const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(grad, "grad");
CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
gated_input.flat_last_dim(), ".");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(grad.flat_last_dim() == cols,
"Wrong Grad shape. Expected last dimension (after flattening) [*, ", cols, "], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [",
rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(gated_input.data.shape == output->data.shape,
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
", output shape: ", output->data.shape, ".");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
} else {
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize.cuh
* \brief Quantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
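// Forward quantization dispatcher. Selects a kernel based on the output's
// scaling mode (tensor-scaled FP8, MXFP8, NVFP4, or FP8 block scaling) and the
// options in the quantization config (noop tensor, stochastic rounding,
// 2D NVFP4 quantization, block-scale tensor format).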
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
if (output_tensor->has_columnwise_data()) {
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
NVTE_CHECK(false,
"Columnwise (transposed) output is not supported when NVTE_INT8_SIM_FP8_TENSORWISE is enabled!");
}
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_ACT) {
cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, float, ParamOP, OP>(
*input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
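// Backward quantization dispatcher. Uses the same scaling-mode dispatch as the
// forward helper, but consumes the incoming gradient and optionally fuses dbias
// and/or dAct where the target format supports it.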
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *grad_tensor = convertNVTETensorCheck(grad);
const Tensor *input_tensor = convertNVTETensor(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
NVTE_CHECK(false,
"Columnwise (transposed) output is not supported when NVTE_INT8_SIM_FP8_TENSORWISE is enabled!");
}
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT) {
cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, float, ParamOP, OP>(
*grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*grad_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = grad_tensor->flat_first_dim();
int32_t cols = grad_tensor->flat_last_dim();
auto dtype = grad_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization*/ false>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/grad_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_fp8.cuh
* \brief CUDA kernels to dequantize from FP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
struct DequantizeParam {
const float *scale_inv;
};
__device__ inline float dequantize_func(float value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
}
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, DequantizeParam, dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated_fp8.cuh
* \brief CUDA kernels to cast to FP8 with gated activations.
*/
#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_
#define TRANSFORMER_ENGINE_GATED_FP8_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace kernel {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
static_assert(ITERATIONS >= 1);
#ifndef __HIP_PLATFORM_AMD__
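// TMA-based gated-activation cast kernel. Each block processes a 128x128 chunk
// in 32-row double-buffered stages: inputs are staged into shared memory with
// bulk tensor copies, the (d)activation is applied elementwise, and the scaled
// FP8 result is written back with bulk stores while the block amax is reduced.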
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0;
constexpr size_t in_act_mem = buff_size_aligned_in;
constexpr size_t in_gate_mem = buff_size_aligned_in;
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
// Prefetch data of the first stage
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
} else {
copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y,
in_transaction_size, &mbar[next_it], is_master_thread);
} else {
copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
&mbar[next_it], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
}
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// Activation output (forward) or dAct gradient (backward)
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_sh_curr));
if constexpr (IS_BWD) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_sh_curr));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif
} // namespace kernel
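// Host-side launcher for the TMA gated kernel. Builds 2D tensor maps for the
// activation/gate halves (plus the gradient in the backward pass), sizes the
// dynamic shared memory for the double-buffered staging buffers, and launches
// cast_fp8_gated_kernel. Only rowwise output is supported here.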
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
using namespace kernel;
checkCuDriverContext(stream);
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act{};
alignas(64) CUtensorMap tensor_map_output_gate{};
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
auto kernel = cast_fp8_gated_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_size));
kernel<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
#endif
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_fp8.cuh
* \brief CUDA kernels to quantize to FP8.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
#include "../core/common.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace quantize_2D_kernel {
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
constexpr size_t FP8_BUFFERS_NUM = 2;
constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1;
static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM);
constexpr size_t FP8_BUFFER_DIM_Y = 16;
constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128
constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128
constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM);
#ifndef __HIP_PLATFORM_AMD__
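// TMA-based 2D cast kernel. Each block quantizes a 128x128 chunk in 16-row
// stages, optionally applying the dAct function and accumulating per-column
// partial dbias sums into the float32 workspace for a later reduction.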
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType>
__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output,
float *const dbias_workspace, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows,
const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
const size_t dbias_offset_Y = blockIdx.y + tid_Y;
const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const bool col_out_of_bounds = my_column >= cols;
const size_t dbias_stride = cols;
float partial_dbias = 0.f;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS];
initialize_barriers<FP8_ITERATIONS, FP8_THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
const size_t chunk_offset_Y = block_offset_Y;
const size_t chunk_offset_X = block_offset_X;
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const size_t chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size,
&mbar[prefetch_buff], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff],
is_master_thread);
}
}
#pragma unroll
for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
const size_t buff = iter % FP8_BUFFERS_NUM;
const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
if (next_iter < FP8_ITERATIONS) {
const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
}
}
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
float elt = static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
if constexpr (IS_DACT) {
if (!out_of_bounds) {
partial_dbias += elt;
}
} else {
// Without dAct, out-of-bounds elements are zero (TMA zero-fills OOB loads),
// so they can be accumulated without a bounds check
partial_dbias += elt;
}
}
__builtin_assume(amax >= 0);
if constexpr (IS_DACT) {
if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt));
}
} else {
// Without dAct, out-of-bounds elements are zero (TMA zero-fills OOB loads),
// so they do not affect the amax reduction
amax = fmaxf(amax, fabsf(elt));
}
out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<FP8_PREFETCH_BUFFERS_NUM>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1;
if constexpr (IS_DBIAS) {
const size_t dbias_offset_X = my_column;
const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias;
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<FP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif
} // namespace quantize_2D_kernel
namespace quantize_1D_kernel {
using namespace quantize_2D_kernel;
constexpr size_t CHUNKS_PER_BLOCK = 128;
constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;
constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK;
constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE;
constexpr size_t CHUNKS_PER_ITERATION = 32;
constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE;
constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION;
constexpr size_t SHMEM_BUFFERS = 2;
static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0);
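// TMA-based 1D cast kernel for contiguous full-tile tensors. Each block handles
// ELEMS_PER_BLOCK consecutive elements in double-buffered iterations, optionally
// applying an activation before scaling to the 8-bit output type and reducing
// the block amax.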
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &), typename IType,
typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_BLOCK>(mbar, is_master_thread);
int parity = 0;
copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread);
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const size_t buff = iter % SHMEM_BUFFERS;
const size_t it_offset = iter * SHMEM_DIM;
const size_t next_iter = iter + 1;
const size_t next_buff = next_iter % SHMEM_BUFFERS;
const size_t next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
&(mbar[next_iter]), is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(elt));
out_sh[buff][shmem_offset] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
ptx::cp_async_bulk_tensor_1d_shared_to_global(
reinterpret_cast<uint64_t *>(output + it_offset),
reinterpret_cast<uint64_t *>(&out_sh[buff]), transaction_size_OUT);
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_BLOCK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace quantize_1D_kernel
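// Launches the 1D cast kernel. Requires the flattened tensor size to be a
// multiple of ELEMS_PER_BLOCK and an FP8/INT8 output with an allocated
// scale_inv tensor.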
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
using namespace quantize_1D_kernel;
const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
NVTE_CHECK(isFullTile, "Only full tiles are supported.");
NVTE_CHECK(is_fp8_dtype(output->dtype()) || is_int8_dtype(output->dtype()),
"Output must have FP8 or INT8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
const size_t chunks = DIVUP(N, CHUNK_SIZE);
const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
const float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(THREADS_PER_BLOCK);
const dim3 grid(blocks);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
output->dtype(), OType,
const IType *input_ptr = reinterpret_cast<const IType *>(input.data.dptr);
OType *output_ptr = reinterpret_cast<OType *>(output->data.dptr);
cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*)
); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
}
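// Launches the 2D cast kernel. With IS_DBIAS, a first call with an unallocated
// workspace only sets the required [blocks_Y, cols] float32 workspace shape and
// returns; a subsequent call performs the cast and the dbias reduction.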
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
using namespace quantize_2D_kernel;
checkCuDriverContext(stream);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
const size_t blocks_Y = chunks_Y;
const size_t blocks_X = chunks_X;
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(FP8_THREADS_PER_CHUNK);
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);
NVTE_CHECK_CUDA(cudaGetLastError());
if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
#endif
}
namespace detail {
using Empty = transformer_engine::Empty;
__device__ inline float identity(float value, const Empty &) { return value; }
} // namespace detail
#ifdef __HIP_PLATFORM_AMD__
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
struct KernelType
{
static constexpr auto op = OP;
};
#endif
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
using kernel = KernelType<ParamOP, OP>;
constexpr float (*UnaryOP)(float, const ParamOP &) =
(kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const fp32 *>(noop->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
using kernel = KernelType<ParamOP, OP>;
constexpr float (*UnaryOP)(float, const ParamOP &) =
(kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
const size_t N = product(input->data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input->data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
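
// Top-level FP8 dispatcher with optional fused dbias / activation-gradient / activation.
// On compute capability >= 10.0 the TMA-based quantize_1D / quantize_2D kernels are used;
// the plain-cast and dAct-only cases fall back to the vectorized launchers above when the
// shape or alignment requirements are not met. On older architectures only the vectorized
// launchers are available, and dbias fusion is not supported.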
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output,
Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace quantize_1D_kernel;
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias != nullptr);
CheckOutputTensor(*dbias, "dbias");
}
if constexpr (IS_DACT) {
NVTE_CHECK(act_input != nullptr);
CheckInputTensor(*act_input, "activation_input");
NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match.");
}
NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
  // TMA-based fast paths are only supported on compute capability >= 10.0
if (is_supported_by_CC_100()) {
if (!IS_DBIAS && !IS_DACT) {
if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8
quantize_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else {
// Unaligned
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
}
} else if (!IS_DBIAS && IS_DACT) {
if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 (+dAct)
quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
} else {
// Unaligned
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
} else {
quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
}
} else {
    if (IS_DBIAS) {
      // TODO(zhongboz): should IS_ACT simply be ignored on this fallback path?
      NVTE_ERROR("DBias fusion with scaling mode " + to_string(output->scaling_mode) +
                 " is not implemented on GPUs with compute capability < 10.0.");
    }
if (!IS_DACT) {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
} else {
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
}
}
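
// Illustrative (hypothetical) instantiation of the dispatcher above for a plain cast with
// no fused dbias, activation gradient, or activation; actual call sites live in the
// dispatch layer that backs the public C API:
//
//   // quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>
//   quantize<false, false, false, detail::Empty, nullptr>(
//       input, /*act_input=*/nullptr, /*noop=*/&noop, &output,
//       /*dbias=*/nullptr, /*workspace=*/nullptr, stream);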
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_