"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "58c3ac80fab933db62559585ce2592951b3f14df"
Commit 0a5016b1 authored by wenjh
Browse files

Merge nv release_v2.9


Signed-off-by: wenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
......@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
x_named_sharding = NamedSharding(mesh, x_pspec)
mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
mask_ = jax.device_put(mask, mask_named_sharding)
with warnings.catch_warnings(record=True) as warns:
try:
......@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)),
in_shardings=(x_named_sharding, mask_named_sharding),
out_shardings=(None, x_named_sharding),
)
except AssertionError as err:
# Softmax should still produce the correct numerical result with
......
......@@ -378,14 +378,14 @@ class FusedAttnRunner:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if (
get_device_compute_capability(0) == 100
get_device_compute_capability(0) >= 100
and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS
):
pytest.skip(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
"For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
......
......@@ -3,11 +3,13 @@
# See LICENSE for license information.
import unittest
from functools import partial
import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.common.recipe import (
......@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode,
update_collections,
TensorSource,
QuantizerFactory,
QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
    """Run ``assertion_func`` on each quantizer of a set from inside a custom-VJP function.

    The quantizer set is passed as an argument of a ``jax.custom_vjp`` function so
    that the assertions see the quantizers as they arrive inside that function.

    Args:
        outer_quantizer_set: Object exposing ``x``, ``kernel`` and ``dgrad`` quantizers.
        assertion_func: Callable invoked as ``assertion_func(quantizer, tensor_source)``
            once per quantizer; declared non-differentiable (argument index 1).
        x: Value returned unchanged by the checked function.

    Returns:
        The result of the custom-VJP function, which returns ``x``.
    """
    # Attribute on the quantizer set paired with the tensor source it is checked as.
    checks = (
        ("x", TensorSource.X),
        ("kernel", TensorSource.KERNEL),
        ("dgrad", TensorSource.DGRAD),
    )

    def _check_fwd(quantizer_set, check, value):
        for attr_name, source in checks:
            check(getattr(quantizer_set, attr_name), source)
        return value

    def _check_bwd(ctx, g):
        return (g,)

    # NOTE(review): the jax.custom_vjp documentation describes the forward rule as
    # returning an (output, residuals) pair and the backward rule as receiving the
    # non-differentiable arguments first. The rules here are kept exactly as they
    # were; confirm whether they are ever reached by differentiation.
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def _check(quantizer_set, check, value):
        return _check_fwd(quantizer_set, check, value)

    _check.defvjp(_check_fwd, _check_bwd)
    return _check(outer_quantizer_set, assertion_func, x)
class TestModule(TransformerEngineBase):
    """Minimal module that builds a quantizer set and hands it to ``quantizer_check_vjp``.

    Used to test quantizer creation and reconstruction across VJP boundaries.
    """

    # Callback that receives each quantizer together with its tensor source.
    # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
    assertion_func: callable

    @nn.compact
    def __call__(self, x):
        """Generate the module's quantizer set, run the checks, and return the checker's result."""
        return quantizer_check_vjp(self.generate_quantizer_set(), self.assertion_func, x)
class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
......@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL
if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)
def _compare_nvfp4_scaling_quantizers(self, test):
    """Check stochastic-rounding and RHT state of the quantizers built under an NVFP4 recipe.

    The assertions run inside ``TestModule``, i.e. through ``quantizer_check_vjp``,
    so the state is inspected on the quantizers as seen inside the custom-VJP function.

    Args:
        test: Recipe object; ``disable_stochastic_rounding`` and ``disable_rht``
            are read from it.
    """

    def assertion_func(quantizer, tensor_source):
        # An RNG state is expected only for the DGRAD quantizer, and only while
        # stochastic rounding is left enabled by the recipe.
        wants_rng_state = (
            not test.disable_stochastic_rounding and tensor_source == TensorSource.DGRAD
        )
        if wants_rng_state:
            self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
        else:
            self.assertIsNone(quantizer.stochastic_rounding_rng_state)

        # RHT is expected for 1D-scaled quantizers that produce a column-wise
        # layout, unless the recipe disables it.
        expected_rht = (
            quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
            and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
            and not test.disable_rht
        )
        self.assertEqual(quantizer.use_rht, expected_rht)

    scalar_input = jnp.ones((), dtype=jnp.float32)
    module = TestModule(assertion_func=assertion_func)

    root_key = jax.random.PRNGKey(0)
    param_key, sr_key = jax.random.split(root_key)
    rngs = {"params": param_key, "sr_rng": sr_key}

    variables = module.init(rngs, scalar_input)
    value_and_grad_fn = jax.jit(jax.value_and_grad(module.apply), static_argnums=(2,))
    value_and_grad_fn(variables, scalar_input, rngs=rngs)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
......@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state()
......@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
......@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
max_logit = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out = core_attn(
......@@ -322,6 +324,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_logit:
out, max_logit = out
if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
......@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context = nullcontext()
# run attention
max_logit_ = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out_ = core_attn(
......@@ -414,6 +419,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_logit:
out_, max_logit_ = out_
if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
......@@ -495,15 +502,15 @@ def run_dpa_with_cp(
)
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
......
......@@ -60,8 +60,16 @@ from utils import (
get_available_attention_backends,
)
# Check if hardware supports FP8
# Check if hardware supports FP8 attention.
# FP8 attention starts from the general FP8 verdict and is additionally ruled out
# on devices whose compute capability is below (9, 0) or at/above (12, 0).
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available = fp8_available
reason_for_no_fp8_attn = reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and not (
    device_compute_capability >= (9, 0) and device_compute_capability < (12, 0)
):
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
    )
# Reset RNG seed and states
seed = 1234
......@@ -130,6 +138,11 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
)
# Get backends
is_training = True
......@@ -171,7 +184,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
......@@ -185,7 +198,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -197,7 +210,7 @@ def test_dot_product_attention(
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -208,7 +221,7 @@ def test_dot_product_attention(
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -221,7 +234,7 @@ def test_dot_product_attention(
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
......@@ -242,6 +255,8 @@ def test_dot_product_attention(
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_logit:
torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
......@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
# Configurations parametrizing test_dpa_max_logit.
model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(
        2,
        1,
        16,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="padding_causal",
    ),
    "max_logit_4": ModelConfig(
        8,
        128,
        16,
        192,
        max_seqlen_kv=2048,
        attn_bias_type="post_scale_bias",
    ),
    "max_logit_5": ModelConfig(
        8,
        128,
        16,
        512,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        window_size=(20, 0),
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with ``return_max_logit`` enabled."""
    # Switch the flag on for the selected (shared) config entry, then reuse the
    # generic dot-product-attention test with the requested QKV layout.
    model_configs[model].return_max_logit = True
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
......@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
......@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
max_logit = None
if config.return_max_logit:
out, max_logit = out
if is_training:
out.backward(d_out)
d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
return (
out_orig,
max_logit,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else:
return out_orig, (None, None, None, d_softmax_offset)
return out_orig, max_logit, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_logit, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
......@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
}
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
......@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
......@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
......@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
),
reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
......
......@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
......@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......
......@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
from utils import ModelConfig, reset_rng_states
# Only run FP8 tests on supported devices.
......@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True)
def is_fused_attn_available(
    config: ModelConfig,
    dtype: torch.dtype,
    qkv_layout="bshd_bshd_bshd",
    is_training=True,
    deterministic=False,
):
    """Report whether the ``F16_arbitrary_seqlen`` fused-attention backend is available.

    The query is delegated to ``get_available_attention_backends``; only the
    third element of its three-element result is inspected.

    Args:
        config: Model configuration forwarded to the backend query.
        dtype: Forwarded as ``qkv_dtype``.
        qkv_layout: Forwarded unchanged.
        is_training: Forwarded unchanged.
        deterministic: Forwarded unchanged.

    Returns:
        ``True`` if ``FusedAttnBackend["F16_arbitrary_seqlen"]`` is among the
        fused-attention backends reported by the query.
    """
    backend_query = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        is_training=is_training,
        deterministic=deterministic,
    )
    _, _, fused_attn_backends = backend_query
    wanted_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
    return wanted_backend in fused_attn_backends
def get_causal_attn_mask(sq: int) -> torch.Tensor:
    """Build an ``sq`` x ``sq`` boolean mask on the CUDA device.

    Entries strictly above the main diagonal are ``True``; all others are ``False``.
    """
    all_ones = torch.ones(sq, sq, device="cuda")
    return all_ones.triu(diagonal=1).bool()
......@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, deterministic=True):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
......@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
......@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
config.hidden_size,
......
......@@ -205,6 +205,7 @@ class ModelConfig:
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_logit=False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -233,6 +234,7 @@ class ModelConfig:
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_logit = return_max_logit
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
)
(
use_flash_attention,
......
......@@ -29,35 +29,80 @@ endif()
# Language options
if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
if (CUDAToolkit_VERSION VERSION_LESS 12.1)
message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif()
# Process GPU architectures
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
#
# For each of the architectures 100, 101, 110 and 120 that is present in
# CMAKE_CUDA_ARCHITECTURES, the entry is removed from that list and recorded
# twice: the plain number goes to NVTE_GENERIC_ARCHS and a suffixed variant
# ("a" or "f") goes to NVTE_SPECIFIC_ARCHS. Architectures not listed here stay
# in CMAKE_CUDA_ARCHITECTURES untouched.
# Both lists start out empty.
set(NVTE_GENERIC_ARCHS)
set(NVTE_SPECIFIC_ARCHS)
# Check for architecture 100
# list(FIND ...) stores -1 in the index variable when the value is absent.
list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
if(NOT arch_100_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
list(APPEND NVTE_GENERIC_ARCHS "100")
list(APPEND NVTE_SPECIFIC_ARCHS "100a")
# With toolkit 12.9 or newer, 103a is added alongside 100a.
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "103a")
endif()
endif()
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
if(NOT arch_101_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
list(APPEND NVTE_GENERIC_ARCHS "101")
list(APPEND NVTE_SPECIFIC_ARCHS "101a")
endif()
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
if(NOT arch_110_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
list(APPEND NVTE_GENERIC_ARCHS "110")
list(APPEND NVTE_SPECIFIC_ARCHS "110f")
endif()
# Check for architecture 120
list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
if(NOT arch_120_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
list(APPEND NVTE_GENERIC_ARCHS "120")
# The specific variant depends on the toolkit: 120f from 12.9 on, 120a before.
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "120f")
else()
list(APPEND NVTE_SPECIFIC_ARCHS "120a")
endif()
endif()
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
......@@ -135,139 +180,206 @@ endif()
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
set(transformer_engine_cpp_sources)
set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)
if(USE_CUDA)
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
transformer_engine.cpp
fused_attn/fused_attn.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/padding.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
# Set compile options for CUDA sources with generic architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
# Set compile options for CUDA sources with specific architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else()
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
dropout/dropout.cu
activation/relu.cu
activation/swiglu.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
transformer_engine.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/padding.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
......@@ -316,10 +428,12 @@ if (USE_CUDA)
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
......@@ -436,30 +550,36 @@ target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
set(nvte_sources_with_fast_math)
list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu)
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu)
endif()
if(USE_CUDA)
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS "--use_fast_math")
endforeach()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()
......
......@@ -8,22 +8,18 @@ import ctypes
import functools
import glob
import importlib
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
from importlib.metadata import version, distribution, PackageNotFoundError
import os
from pathlib import Path
import platform
import subprocess
import sys
import sysconfig
from typing import Optional
from typing import Optional, Tuple
from torch.utils.cpp_extension import IS_HIP_EXTENSION
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package) -> bool:
def _is_package_installed(package) -> bool:
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
......@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
distribution(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _is_package_installed_from_wheel(package) -> bool:
    """Check whether `package` was installed from a pure-Python wheel.

    The distribution's file manifest is searched for a ``WHEEL`` metadata
    file, and the result is taken from that file's ``Root-Is-Purelib`` field.

    Args:
        package: Distribution name to look up.

    Returns:
        True only if the package is installed, a ``WHEEL`` file is listed in
        its manifest, and that file declares ``Root-Is-Purelib: true``.
        False in every other case, including when the distribution has no
        file manifest at all.
    """
    if not _is_package_installed(package):
        return False

    te_dist = distribution(package)
    te_wheel_file = ""
    # `Distribution.files` is None when the manifest (RECORD / SOURCES.txt) is
    # missing; treat that the same as "no WHEEL file found".
    for file_path in te_dist.files or ():
        if file_path.name == "WHEEL":
            te_wheel_file = te_dist.locate_file("") / file_path

    if not te_wheel_file:
        return False

    with te_wheel_file.open("r") as f:
        for line in f:
            if line.startswith("Root-Is-Purelib:"):
                return line.strip().split(":")[1].strip().lower() == "true"
    # WHEEL file present but without a Root-Is-Purelib entry.
    return False
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
"""
......@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
)
def get_te_core_package_info() -> Tuple[bool, str, str]:
    """
    Look up the installed Transformer Engine core package, if any.

    The candidates `transformer-engine-cu12` and `transformer-engine-cu13`
    are probed in that order and the first installed one is reported.

    Returns a `(found, package_name, package_version)` tuple; when no
    candidate is installed the result is `(False, "", "")`.
    """
    candidates = ("transformer-engine-cu12", "transformer-engine-cu13")
    found_name = next((name for name in candidates if _is_package_installed(name)), None)
    if found_name is None:
        return False, "", ""
    return True, found_name, version(found_name)
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None:
"""
......@@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch":
extra_dep_name = "pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
# For the `transformer-engine` package, we need to check explicitly.
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_framework_installed = _is_package_installed(module_name)
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
# extension are all installed via PyPI and have matching versions.
if te_framework_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
assert version(module_name) == version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
f" v{version('transformer-engine')}, and {te_core_package_name}"
f" v{te_core_version}. Install transformer-engine using "
f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
)
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
# After all checks are completed, load the shared object file.
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
solib = importlib.util.module_from_spec(spec)
......@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
spec.loader.exec_module(solib)
def sanity_checks_for_pypi_installation() -> None:
    """Ensure that package is installed correctly if using PyPI.

    Checks performed:
      * `transformer_engine` must be discoverable as an installed package.
      * If a TE core package (`transformer-engine-cu*`) is installed, then
        `transformer-engine` must be reported as wheel-installed by
        `_is_package_installed_from_wheel` and both versions must match.
      * If only the wheel-installed `transformer-engine` metapackage is
        present (no core package), the installation is rejected.

    Raises:
        AssertionError: `transformer_engine` is missing, or the core package
            is present but the metapackage is missing or has another version.
        RuntimeError: Only the empty metapackage is installed.
    """
    te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
    te_installed = _is_package_installed("transformer_engine")
    te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")

    assert te_installed, "Could not find `transformer-engine`."

    # If the core package is installed via PyPI.
    if te_core_installed:
        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
        assert version("transformer-engine") == te_core_version, (
            "Transformer Engine package version mismatch. Found "
            f"transformer-engine v{version('transformer-engine')} "
            f"and {te_core_package_name} v{te_core_version}."
        )
    # Only the metapackage is found, invalid usecase.
    elif te_installed_via_pypi:
        # Message pieces are concatenated; each piece carries its own
        # separating whitespace and every quoted command is closed with the
        # same quote character it was opened with.
        raise RuntimeError(
            "Found empty `transformer-engine` meta package installed. "
            "Install `transformer-engine` with framework extensions via "
            "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
            " or 'pip3 install transformer-engine[core]' for the TE core lib only. The `core_cu12`"
            " or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
        )
@functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str:
"""File extension for shared objects."""
......@@ -339,16 +390,14 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
......@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool return_max_logit) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
(cudnn_runtime_version != 91000) && !return_max_logit) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
......@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset &&
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) {
flag_m512 = true;
}
if (
......@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
size_t max_seqlen, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
......@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
......@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset,
output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked(
}
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......
......@@ -20,12 +20,13 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
......@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......
......@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
qkv_tensor_type,
o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET};
cudnn_frontend::DataType_t::NOT_SET,
false};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
qkv_tensor_type,
o_tensor_type,
do_tensor_type,
dqkv_tensor_type};
dqkv_tensor_type,
false};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......
......@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type;
bool generate_max_sum_exp;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type) <
o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type);
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
}
};
......
......@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile( \
"{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
"}" \
: "=h"(output_ptr[0]),
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile( \
"{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
"}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1]));
#else
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
} else {
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return output;
}
......
......@@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool return_max_logit);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
......@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
......@@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......
......@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt.rs PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY;
if constexpr (has_fp4) {
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
}
template <bool kApplyStochasticRounding>
......
......@@ -15,10 +15,9 @@
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
#include <cfloat>
#include "../common.h"
......@@ -30,7 +29,7 @@
namespace transformer_engine {
#if CUDA_VERSION > 12080
#if FP4_TYPE_SUPPORTED
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
......@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
return rbits;
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if constexpr (is_blackwell) {
// NOTE: rbits unused for rn.
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
......@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
......@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
const float2 in23,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if constexpr (is_blackwell) {
// NOTE: rbits unused for rn.
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
......@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
......@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
#endif // FP4_TYPE_SUPPORTED
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
#if FP4_TYPE_SUPPORTED
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to
......@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
}
} // namespace transformer_engine
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment