Commit 0a5016b1 authored by wenjh

Merge nv release_v2.9


Signed-off-by: wenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
...@@ -103,8 +103,10 @@ class TestDistributedSoftmax: ...@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource): with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_named_sharding = NamedSharding(mesh, x_pspec)
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
mask_ = jax.device_put(mask, mask_named_sharding)
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
...@@ -116,8 +118,8 @@ class TestDistributedSoftmax: ...@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
grad_args=(0,), grad_args=(0,),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec), in_shardings=(x_named_sharding, mask_named_sharding),
out_shardings=(None, (x_pspec,)), out_shardings=(None, x_named_sharding),
) )
except AssertionError as err: except AssertionError as err:
# Softmax should still produce the correct numerical result with # Softmax should still produce the correct numerical result with
......
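Editor's note on the hunk above: the updated test builds a NamedSharding once and reuses it both for jax.device_put and for the jitted function's in_shardings/out_shardings, instead of passing bare PartitionSpecs. A minimal, self-contained sketch of that pattern (the mesh axis name, shapes, and softmax body are illustrative, not taken from the test):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh from whatever devices are available (1 or 2 here).
num_devices = min(2, jax.device_count())
devices = np.asarray(jax.devices()[:num_devices])
mesh = Mesh(devices, ("dp",))

x = jnp.ones((8, 16))
x_sharding = NamedSharding(mesh, PartitionSpec("dp", None))

# Reuse the same NamedSharding for placement and for compilation constraints.
x_sharded = jax.device_put(x, x_sharding)
softmax = jax.jit(
    lambda t: jax.nn.softmax(t, axis=-1),
    in_shardings=(x_sharding,),
    out_shardings=x_sharding,
)
print(softmax(x_sharded).sharding)
```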
...@@ -378,14 +378,14 @@ class FusedAttnRunner: ...@@ -378,14 +378,14 @@ class FusedAttnRunner:
pytest.skip( pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
) )
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if ( if (
get_device_compute_capability(0) == 100 get_device_compute_capability(0) >= 100
and self.dropout_prob == 0.1 and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS and self.attn_bias_type is not AttnBiasType.NO_BIAS
): ):
pytest.skip( pytest.skip(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported" "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
) )
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
......
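The guard above now skips on sm100 and newer (>= 100) rather than exactly sm100. A hedged sketch of the same check as a standalone helper, assuming the compute capability is available as an integer such as 90, 100, or 120:

```python
import pytest

def maybe_skip_dropout_with_bias(compute_capability: int, dropout_prob: float, has_bias: bool):
    """Skip the current test when the bprop kernel lacks dropout + bias support."""
    if compute_capability >= 100 and dropout_prob > 0.0 and has_bias:
        pytest.skip("sm100+: bprop with dropout + determinism (bias) is not supported")
```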
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
# See LICENSE for license information. # See LICENSE for license information.
import unittest import unittest
from functools import partial
import flax import flax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
...@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import ( ...@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode, ScalingMode,
update_collections, update_collections,
TensorSource, TensorSource,
QuantizerFactory,
QuantizeLayout,
) )
from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)
def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x
def quantizer_check_bwd(ctx, g):
return (g,)
quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x)
class TestModule(TransformerEngineBase):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable
@nn.compact
def __call__(self, x):
quantizer_set = self.generate_quantizer_set()
return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
class TestHelper(unittest.TestCase): class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
...@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase): ...@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for tensor_source in TensorSource: for tensor_source in TensorSource:
target_scaling_mode = ( target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING else ScalingMode.NVFP4_1D_SCALING
) )
self.assertEqual( self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
) )
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)
def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def assertion_func(quantizer, tensor_source):
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht
)
self.assertEqual(quantizer.use_rht, expected_rht)
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self): def test_autocast_delayed_scaling(self):
...@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase): ...@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs) self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state() self._check_default_state()
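The quantizer_check_vjp helper above uses jax.custom_vjp with nondiff_argnums so that a plain Python callback can inspect non-differentiable objects on the forward pass while gradients flow through unchanged. A minimal, self-contained sketch of that mechanism (the passthrough/check_fn names are illustrative and not part of the test):

```python
from functools import partial

import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def passthrough(x, check_fn):
    check_fn(x)  # run the non-differentiable inspection callback
    return x

def passthrough_fwd(x, check_fn):
    check_fn(x)
    return x, None  # (primal output, residuals)

def passthrough_bwd(check_fn, residuals, g):
    return (g,)  # one cotangent, for the single differentiable argument x

passthrough.defvjp(passthrough_fwd, passthrough_bwd)

value, grad = jax.value_and_grad(lambda x: passthrough(x, lambda t: None) ** 2)(3.0)
print(value, grad)  # 9.0 6.0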
...@@ -248,6 +248,7 @@ def run_dpa_with_cp( ...@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type=config.attn_mask_type, attn_mask_type=config.attn_mask_type,
window_size=config.window_size, window_size=config.window_size,
softmax_type=config.softmax_type, softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).cuda() ).cuda()
if config.softmax_type != "vanilla": if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True core_attn.softmax_offset.requires_grad = True
...@@ -308,6 +309,7 @@ def run_dpa_with_cp( ...@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else: else:
fp8_context = nullcontext() fp8_context = nullcontext()
max_logit = None
with fp8_context: with fp8_context:
# q, k, v, out in FP8; dout in F16 # q, k, v, out in FP8; dout in F16
out = core_attn( out = core_attn(
...@@ -322,6 +324,8 @@ def run_dpa_with_cp( ...@@ -322,6 +324,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha, fp8_output=fp8_mha,
) )
if config.return_max_logit:
out, max_logit = out
if fp8_bwd and fp8_mha: if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout) dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8) out.backward(dout_fp8)
...@@ -400,6 +404,7 @@ def run_dpa_with_cp( ...@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context = nullcontext() fp8_context = nullcontext()
# run attention # run attention
max_logit_ = None
with fp8_context: with fp8_context:
# q, k, v, out in FP8; dout in F16 # q, k, v, out in FP8; dout in F16
out_ = core_attn( out_ = core_attn(
...@@ -414,6 +419,8 @@ def run_dpa_with_cp( ...@@ -414,6 +419,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha, fp8_output=fp8_mha,
) )
if config.return_max_logit:
out_, max_logit_ = out_
if fp8_bwd and fp8_mha: if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_) dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_) out_.backward(dout_fp8_)
...@@ -495,15 +502,15 @@ def run_dpa_with_cp( ...@@ -495,15 +502,15 @@ def run_dpa_with_cp(
) )
atol, rtol, rmse_tol = get_tols(config, dtype) atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"] names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names] names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8" is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp): for i, t in enumerate(tensors_no_cp):
if t is not None: if t is not None:
if "softmax_offset" not in names[i]: if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd": if qkv_format == "bshd":
compare_and_assert( compare_and_assert(
t[:, 0], t[:, 0],
......
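With max_logit appended to the comparison lists above, the loop now treats it like d_softmax_offset: compared whole rather than sliced per sequence, and skipped when None. A hedged sketch of that pattern, using torch.testing.assert_close in place of the file's compare_and_assert helper and showing only the bshd slicing branch:

```python
import torch

def compare_cp_results(tensors_cp, tensors_no_cp, names, **tols):
    for t_cp, t, name in zip(tensors_cp, tensors_no_cp, names):
        if t is None:
            continue  # feature (softmax offset, max_logit, ...) not enabled in this config
        if "softmax_offset" in name or "max_logit" in name:
            # Auxiliary outputs are compared as-is.
            torch.testing.assert_close(t_cp, t, **tols)
        else:
            # Attention outputs/gradients are sliced (first sequence of a bshd
            # tensor here) before comparing the CP and non-CP runs.
            torch.testing.assert_close(t_cp[:, 0], t[:, 0], **tols)
```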
...@@ -60,8 +60,16 @@ from utils import ( ...@@ -60,8 +60,16 @@ from utils import (
get_available_attention_backends, get_available_attention_backends,
) )
# Check if hardware supports FP8 # Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
fp8_attn_available = False
reason_for_no_fp8_attn = (
"FP8 attention is not supported for compute capability ="
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)
# Reset RNG seed and states # Reset RNG seed and states
seed = 1234 seed = 1234
...@@ -130,6 +138,11 @@ def test_dot_product_attention( ...@@ -130,6 +138,11 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa: if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2] config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
)
# Get backends # Get backends
is_training = True is_training = True
...@@ -171,7 +184,7 @@ def test_dot_product_attention( ...@@ -171,7 +184,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"UnfusedDotProductAttention", "UnfusedDotProductAttention",
...@@ -185,7 +198,7 @@ def test_dot_product_attention( ...@@ -185,7 +198,7 @@ def test_dot_product_attention(
# FusedAttention backend # FusedAttention backend
if fused_attn_supported: if fused_attn_supported:
if len(fused_attn_backends) == 1: if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -197,7 +210,7 @@ def test_dot_product_attention( ...@@ -197,7 +210,7 @@ def test_dot_product_attention(
) )
if len(fused_attn_backends) == 2: if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -208,7 +221,7 @@ def test_dot_product_attention( ...@@ -208,7 +221,7 @@ def test_dot_product_attention(
is_training, is_training,
) )
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, dtype,
config, config,
"FusedAttention", "FusedAttention",
...@@ -221,7 +234,7 @@ def test_dot_product_attention( ...@@ -221,7 +234,7 @@ def test_dot_product_attention(
# FlashAttention backend # FlashAttention backend
if flash_attn_supported: if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
"FlashAttention", "FlashAttention",
...@@ -242,6 +255,8 @@ def test_dot_product_attention( ...@@ -242,6 +255,8 @@ def test_dot_product_attention(
if unfused_attn_supported and fused_attn_supported: if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn") logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_logit:
torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
for i, _ in enumerate(unfused_attn_bwd): for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported: if fused_attn_supported and flash_attn_supported:
...@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model): ...@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_max_logit = {
# test: ModelConfig(b, sq, hq, dqk)
"max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"max_logit_4": ModelConfig(
8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
),
"max_logit_5": ModelConfig(
8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
),
"max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with return_max_logit enabled"""
config = model_configs[model]
config.return_max_logit = True
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_softmax = { model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk) # test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
...@@ -962,6 +1004,8 @@ def _run_dot_product_attention( ...@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
layout = layout.replace("d", "dqk") layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig = tensor tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...@@ -1071,6 +1115,7 @@ def _run_dot_product_attention( ...@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
layer_number=1, layer_number=1,
attention_type=config.attn_type, attention_type=config.attn_type,
softmax_type=config.softmax_type, softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
if not is_training: if not is_training:
block = block.eval() block = block.eval()
...@@ -1108,16 +1153,21 @@ def _run_dot_product_attention( ...@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
fast_zero_fill=True, fast_zero_fill=True,
) )
max_logit = None
if config.return_max_logit:
out, max_logit = out
if is_training: if is_training:
out.backward(d_out) out.backward(d_out)
d_softmax_offset = None d_softmax_offset = None
if is_training and config.softmax_type != "vanilla": if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training: if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset) return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else: else:
return out, (None, None, None, d_softmax_offset) return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention": if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...@@ -1146,14 +1196,18 @@ def _run_dot_product_attention( ...@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
) )
if is_training: if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) return (
out_orig,
max_logit,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else: else:
return out_orig, (None, None, None, d_softmax_offset) return out_orig, max_logit, (None, None, None, d_softmax_offset)
else: else:
if is_training: if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset) return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else: else:
return out, (None, None, None, d_softmax_offset) return out, max_logit, (None, None, None, d_softmax_offset)
model_configs_te_layer = { model_configs_te_layer = {
...@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = { ...@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
} }
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
...@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] ...@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
...@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16( ...@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
...@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"] ...@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
), ),
reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
) )
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model): def test_custom_mha_fp8_vs_f16(dtype, model):
......
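The new fp8_attn_available flag above replaces the per-test "FP8 tests require Hopper+" skips with a single capability window (sm90 up to, but not including, sm120). A hedged equivalent using plain PyTorch; the test file itself relies on Transformer Engine's is_fp8_available and get_device_compute_capability, so the helper below is only illustrative:

```python
import torch

def fp8_attention_available():
    """Return (available, reason) for FP8 attention on the current GPU."""
    if not torch.cuda.is_available():
        return False, "CUDA device required"
    cc = torch.cuda.get_device_capability()
    if cc < (9, 0) or cc >= (12, 0):
        return False, f"FP8 attention is not supported for compute capability = sm{cc[0] * 10 + cc[1]}"
    return True, ""
```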
...@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = { model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk) # test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA
"cp_1_2": ModelConfig( "cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA ), # MHA
...@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"] ...@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"] qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential: if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"] dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"] qkv_formats = ["sbhd", "thd"]
......
...@@ -45,11 +45,10 @@ from transformer_engine.pytorch import ( ...@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends from utils import ModelConfig, reset_rng_states
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
...@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0): ...@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True) use_cutlass_grouped_gemm.append(True)
def is_fused_attn_available(
config: ModelConfig,
dtype: torch.dtype,
qkv_layout="bshd_bshd_bshd",
is_training=True,
deterministic=False,
):
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=deterministic,
)
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
def get_causal_attn_mask(sq: int) -> torch.Tensor: def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
...@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ...@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model): def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(config, dtype, deterministic=True):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
...@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): ...@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer( te_gpt = TransformerLayer(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): ...@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types) @pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type): def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention( te_mha = MultiheadAttention(
config.hidden_size, config.hidden_size,
......
...@@ -205,6 +205,7 @@ class ModelConfig: ...@@ -205,6 +205,7 @@ class ModelConfig:
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False, context_parallel: bool = False,
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
return_max_logit=False,
total_requests: int = None, total_requests: int = None,
max_ctx_len: int = None, max_ctx_len: int = None,
num_layers: int = 1, num_layers: int = 1,
...@@ -233,6 +234,7 @@ class ModelConfig: ...@@ -233,6 +234,7 @@ class ModelConfig:
self.window_size = check_set_window_size(self.attn_mask_type, window_size) self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type self.cp_comm_type = cp_comm_type
self.return_max_logit = return_max_logit
self.total_requests = total_requests self.total_requests = total_requests
self.max_ctx_len = max_ctx_len self.max_ctx_len = max_ctx_len
self.num_layers = num_layers self.num_layers = num_layers
...@@ -318,6 +320,7 @@ def get_available_attention_backends( ...@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training=is_training, is_training=is_training,
inference_params=inference_params, inference_params=inference_params,
softmax_type=config.softmax_type, softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
) )
( (
use_flash_attention, use_flash_attention,
......
...@@ -29,35 +29,80 @@ endif() ...@@ -29,35 +29,80 @@ endif()
# Language options # Language options
if(USE_CUDA) if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON) set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug") if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif() endif()
# Hide non-necessary symbols in shared object. # Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Transformer Engine library # Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX) project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit # CUDA Toolkit
find_package(CUDAToolkit REQUIRED) find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0) if (CUDAToolkit_VERSION VERSION_LESS 12.1)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif() endif()
# Process GPU architectures
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
set(NVTE_GENERIC_ARCHS)
set(NVTE_SPECIFIC_ARCHS)
# Check for architecture 100
list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
if(NOT arch_100_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
list(APPEND NVTE_GENERIC_ARCHS "100")
list(APPEND NVTE_SPECIFIC_ARCHS "100a")
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "103a")
endif()
endif()
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
if(NOT arch_101_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
list(APPEND NVTE_GENERIC_ARCHS "101")
list(APPEND NVTE_SPECIFIC_ARCHS "101a")
endif()
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
if(NOT arch_110_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
list(APPEND NVTE_GENERIC_ARCHS "110")
list(APPEND NVTE_SPECIFIC_ARCHS "110f")
endif()
# Check for architecture 120
list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
if(NOT arch_120_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
list(APPEND NVTE_GENERIC_ARCHS "120")
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "120f")
else()
list(APPEND NVTE_SPECIFIC_ARCHS "120a")
endif()
endif()
# cuDNN frontend API # cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
...@@ -135,139 +180,206 @@ endif() ...@@ -135,139 +180,206 @@ endif()
# Configure Transformer Engine library # Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..) include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES) set(transformer_engine_SOURCES)
set(transformer_engine_cpp_sources)
set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)
if(USE_CUDA) if(USE_CUDA)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp cudnn_utils.cpp
transformer_engine.cpp transformer_engine.cpp
common.cu fused_attn/fused_attn.cpp
multi_tensor/adam.cu gemm/config.cpp
multi_tensor/compute_scale.cu normalization/common.cpp
multi_tensor/l2norm.cu normalization/layernorm/ln_api.cpp
multi_tensor/scale.cu normalization/rmsnorm/rmsnorm_api.cpp
multi_tensor/sgd.cu util/cuda_driver.cpp
transpose/cast_transpose.cu util/cuda_nvml.cpp
transpose/transpose.cu util/cuda_runtime.cpp
transpose/cast_transpose_fusion.cu util/multi_stream.cpp
transpose/transpose_fusion.cu util/rtc.cpp
transpose/multi_cast_transpose.cu comm_gemm_overlap/userbuffers/ipcsocket.cc
transpose/quantize_transpose_square_blockwise.cu comm_gemm_overlap/userbuffers/userbuffers-host.cpp
transpose/quantize_transpose_vector_blockwise.cu comm_gemm_overlap/comm_gemm_overlap.cpp)
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu list(APPEND transformer_engine_cuda_sources
activation/gelu.cu common.cu
dropout/dropout.cu multi_tensor/adam.cu
fused_attn/flash_attn.cu multi_tensor/compute_scale.cu
fused_attn/context_parallel.cu multi_tensor/l2norm.cu
fused_attn/kv_cache.cu multi_tensor/scale.cu
fused_attn/fused_attn_f16_max512_seqlen.cu multi_tensor/sgd.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu transpose/cast_transpose.cu
activation/relu.cu transpose/transpose.cu
activation/swiglu.cu transpose/cast_transpose_fusion.cu
fused_attn/fused_attn_fp8.cu transpose/transpose_fusion.cu
fused_attn/fused_attn.cpp transpose/multi_cast_transpose.cu
fused_attn/utils.cu transpose/quantize_transpose_vector_blockwise.cu
gemm/config.cpp transpose/swap_first_dims.cu
gemm/cublaslt_gemm.cu dropout/dropout.cu
gemm/cutlass_grouped_gemm.cu fused_attn/flash_attn.cu
normalization/common.cpp fused_attn/context_parallel.cu
normalization/layernorm/ln_api.cpp fused_attn/kv_cache.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu fused_attn/fused_attn_f16_max512_seqlen.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu
normalization/rmsnorm/rmsnorm_api.cpp fused_attn/fused_attn_fp8.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu fused_attn/utils.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu gemm/cublaslt_gemm.cu
permutation/permutation.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
util/cast.cu normalization/layernorm/ln_fwd_cuda_kernel.cu
util/padding.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
util/cuda_driver.cpp normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cuda_nvml.cpp permutation/permutation.cu
util/cuda_runtime.cpp util/padding.cu
util/multi_stream.cpp swizzle/swizzle.cu
util/rtc.cpp swizzle/swizzle_block_scaling.cu
swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu
swizzle/swizzle_block_scaling.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_rope/fused_rope.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_router/fused_moe_aux_loss.cu
fused_rope/fused_rope.cu fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_moe_aux_loss.cu fused_router/fused_topk_with_score_function.cu
fused_router/fused_score_for_moe_aux_loss.cu recipe/current_scaling.cu
fused_router/fused_topk_with_score_function.cu recipe/delayed_scaling.cu
recipe/current_scaling.cu recipe/fp8_block_scaling.cu
recipe/delayed_scaling.cu recipe/nvfp4.cu
recipe/fp8_block_scaling.cu comm_gemm_overlap/userbuffers/userbuffers.cu)
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/hadamard_transform_cast_fusion.cu gemm/cutlass_grouped_gemm.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc util/cast.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp activation/gelu.cu
comm_gemm_overlap/userbuffers/userbuffers.cu activation/relu.cu
comm_gemm_overlap/comm_gemm_overlap.cpp) activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
# Set compile options for CUDA sources with generic architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
# Set compile options for CUDA sources with specific architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
if (NVTE_WITH_CUBLASMP) if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp) comm_gemm/comm_gemm.cpp)
endif() endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else() else()
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp cudnn_utils.cpp
transformer_engine.cpp transformer_engine.cpp
common.cu gemm/config.cpp
fused_attn/flash_attn.cu normalization/common.cpp
fused_attn/context_parallel.cu normalization/layernorm/ln_api.cpp
fused_attn/kv_cache.cu normalization/rmsnorm/rmsnorm_api.cpp
multi_tensor/adam.cu util/cuda_driver.cpp
multi_tensor/compute_scale.cu util/cuda_nvml.cpp
multi_tensor/l2norm.cu util/cuda_runtime.cpp
multi_tensor/scale.cu util/multi_stream.cpp
multi_tensor/sgd.cu util/rtc.cpp
transpose/cast_transpose.cu comm_gemm_overlap/userbuffers/ipcsocket.cc
transpose/transpose.cu comm_gemm_overlap/userbuffers/userbuffers-host.cpp
transpose/cast_transpose_fusion.cu comm_gemm_overlap/comm_gemm_overlap.cpp)
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu list(APPEND transformer_engine_cuda_sources
transpose/quantize_transpose_square_blockwise.cu common.cu
transpose/quantize_transpose_vector_blockwise.cu multi_tensor/adam.cu
transpose/swap_first_dims.cu multi_tensor/compute_scale.cu
activation/gelu.cu multi_tensor/l2norm.cu
dropout/dropout.cu multi_tensor/scale.cu
activation/relu.cu multi_tensor/sgd.cu
activation/swiglu.cu transpose/cast_transpose.cu
gemm/config.cpp transpose/transpose.cu
gemm/cublaslt_gemm.cu transpose/cast_transpose_fusion.cu
gemm/hipblas_gemm.cu transpose/transpose_fusion.cu
normalization/common.cpp transpose/multi_cast_transpose.cu
normalization/layernorm/ln_api.cpp transpose/quantize_transpose_vector_blockwise.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu transpose/swap_first_dims.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu dropout/dropout.cu
normalization/rmsnorm/rmsnorm_api.cpp fused_attn/flash_attn.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu fused_attn/context_parallel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu fused_attn/kv_cache.cu
permutation/permutation.cu fused_attn/fused_attn_f16_max512_seqlen.cu
util/cast.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu
util/padding.cu fused_attn/fused_attn_fp8.cu
util/cuda_driver.cpp fused_attn/utils.cu
util/cuda_nvml.cpp gemm/cublaslt_gemm.cu
util/cuda_runtime.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
util/multi_stream.cpp normalization/layernorm/ln_fwd_cuda_kernel.cu
util/rtc.cpp normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
swizzle/swizzle.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
swizzle/swizzle_block_scaling.cu permutation/permutation.cu
fused_softmax/scaled_masked_softmax.cu util/padding.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu swizzle/swizzle.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu swizzle/swizzle_block_scaling.cu
fused_rope/fused_rope.cu fused_softmax/scaled_masked_softmax.cu
fused_router/fused_moe_aux_loss.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_router/fused_score_for_moe_aux_loss.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_router/fused_topk_with_score_function.cu fused_rope/fused_rope.cu
recipe/current_scaling.cu fused_router/fused_moe_aux_loss.cu
recipe/delayed_scaling.cu fused_router/fused_score_for_moe_aux_loss.cu
recipe/fp8_block_scaling.cu fused_router/fused_topk_with_score_function.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc recipe/current_scaling.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/userbuffers.cu recipe/fp8_block_scaling.cu
comm_gemm_overlap/comm_gemm_overlap.cpp) recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
if (NVTE_WITH_CUBLASMP) if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp) comm_gemm/comm_gemm.cpp)
...@@ -316,10 +428,12 @@ if (USE_CUDA) ...@@ -316,10 +428,12 @@ if (USE_CUDA)
CUDA::cublas CUDA::cublas
CUDA::cudart CUDA::cudart
CUDNN::cudnn_all) CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR} ${CUTLASS_INCLUDE_DIR}
...@@ -436,30 +550,36 @@ target_include_directories(transformer_engine PRIVATE ...@@ -436,30 +550,36 @@ target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers") "${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options # Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu set(nvte_sources_with_fast_math)
fused_softmax/scaled_upper_triang_masked_softmax.cu list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
multi_tensor/adam.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/compute_scale.cu multi_tensor/adam.cu
multi_tensor/l2norm.cu multi_tensor/compute_scale.cu
multi_tensor/scale.cu multi_tensor/l2norm.cu
multi_tensor/sgd.cu multi_tensor/scale.cu
fused_attn/flash_attn.cu multi_tensor/sgd.cu
fused_attn/context_parallel.cu fused_attn/flash_attn.cu
fused_attn/kv_cache.cu fused_attn/context_parallel.cu
PROPERTIES fused_attn/kv_cache.cu)
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
util/cast.cu util/cast.cu)
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif() endif()
if(USE_CUDA) if(USE_CUDA)
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS "--use_fast_math")
endforeach()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else() else()
......
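For reference, the per-source compile-option loops added above expand each architecture into an nvcc --generate-code option; generic archs go to every CUDA source while the "a"/"f" variants go only to the arch-specific sources. A hedged Python sketch of that expansion (the example lists assume a CUDA 12.9+ toolkit building sm100 and sm120):

```python
# NVTE_GENERIC_ARCHS / NVTE_SPECIFIC_ARCHS as the CMake logic above would
# produce them for CMAKE_CUDA_ARCHITECTURES = 70;80;89;90;100;120 on CUDA 12.9+.
generic_archs = ["100", "120"]
specific_archs = ["100a", "103a", "120f"]

def gencode_flags(archs):
    return [f"--generate-code=arch=compute_{a},code=sm_{a}" for a in archs]

print(gencode_flags(generic_archs))   # appended to transformer_engine_cuda_sources
print(gencode_flags(specific_archs))  # appended to the arch-specific sources only
```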
...@@ -8,22 +8,18 @@ import ctypes ...@@ -8,22 +8,18 @@ import ctypes
import functools import functools
import glob import glob
import importlib import importlib
from importlib.metadata import version, metadata, PackageNotFoundError from importlib.metadata import version, distribution, PackageNotFoundError
import logging
import os import os
from pathlib import Path from pathlib import Path
import platform import platform
import subprocess import subprocess
import sys import sys
import sysconfig import sysconfig
from typing import Optional from typing import Optional, Tuple
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package) -> bool: def _is_package_installed(package) -> bool:
"""Check if the given package is installed via pip.""" """Check if the given package is installed via pip."""
# This is needed because we only want to return true # This is needed because we only want to return true
...@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool: ...@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to # if it's importable in the current directory due to
# the presence of the shared library module. # the presence of the shared library module.
try: try:
metadata(package) distribution(package)
except PackageNotFoundError: except PackageNotFoundError:
return False return False
return True return True
@functools.lru_cache(maxsize=None)
def _is_package_installed_from_wheel(package) -> bool:
"""Check if the given package is installed via PyPI."""
if not _is_package_installed(package):
return False
te_dist = distribution(package)
te_wheel_file = ""
for file_path in te_dist.files:
if file_path.name == "WHEEL":
te_wheel_file = te_dist.locate_file("") / file_path
if not te_wheel_file:
return False
with te_wheel_file.open("r") as f:
for line in f:
if line.startswith("Root-Is-Purelib:"):
return line.strip().split(":")[1].strip().lower() == "true"
return False
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
""" """
...@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path: ...@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
) )
def get_te_core_package_info() -> Tuple[bool, str, str]:
"""
Check if a Transformer Engine core package is installed.
Returns whether one was found, along with its package name and version.
"""
te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
for package in te_core_packages:
if _is_package_installed(package):
return True, package, version(package)
return False, "", ""
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None: def load_framework_extension(framework: str) -> None:
""" """
...@@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None: ...@@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch": if framework == "torch":
extra_dep_name = "pytorch" extra_dep_name = "pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
# For the `transformer-engine` package, we need to check explicitly.
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_framework_installed = _is_package_installed(module_name)
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via # If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version. # extension are all installed via PyPI and have matching versions.
if _is_pip_package_installed(module_name): if te_framework_installed:
assert _is_pip_package_installed( assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
"transformer_engine" assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
), "Could not find `transformer-engine`."
assert _is_pip_package_installed( assert version(module_name) == version("transformer-engine") == te_core_version, (
"transformer_engine_cu12" "Transformer Engine package version mismatch. Found"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine" f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12" f" v{version('transformer-engine')}, and {te_core_package_name}"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using " f" v{te_core_version}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
) )
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
# After all checks are completed, load the shared object file. # After all checks are completed, load the shared object file.
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
solib = importlib.util.module_from_spec(spec) solib = importlib.util.module_from_spec(spec)
...@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None: ...@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
spec.loader.exec_module(solib) spec.loader.exec_module(solib)
def sanity_checks_for_pypi_installation() -> None:
"""Ensure that package is installed correctly if using PyPI."""
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer-engine`."
# If the core package is installed via PyPI.
if te_core_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found "
f"transformer-engine v{version('transformer-engine')} "
f"and {te_core_package_name} v{te_core_version}."
)
# Only the metapackage is found, which is an invalid use case.
elif te_installed_via_pypi:
raise RuntimeError(
"Found empty `transformer-engine` meta package installed. "
"Install `transformer-engine` with framework extensions via"
"'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
" or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
" or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
)
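To make the rule enforced here easier to follow, a hedged standalone sketch of the same consistency logic (the function and its arguments are illustrative stand-ins, not the module's API):

```python
# Illustrative sketch of the PyPI consistency rule; in the real module the
# versions come from importlib.metadata.
def check_pypi_consistency(meta_from_pypi: bool, meta_version, core_version) -> None:
    if core_version is not None:
        # A core wheel (transformer-engine-cu*) is installed: the metapackage
        # must also come from PyPI and carry the same version.
        assert meta_from_pypi, "Could not find `transformer-engine` PyPI package."
        assert meta_version == core_version, "Transformer Engine package version mismatch."
    elif meta_from_pypi:
        # Only the empty metapackage is present: nothing usable was pulled in.
        raise RuntimeError("Found empty `transformer-engine` meta package installed.")
```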
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str: def _get_sys_extension() -> str:
"""File extension for shared objects.""" """File extension for shared objects."""
...@@ -339,16 +390,14 @@ def _load_core_library(): ...@@ -339,16 +390,14 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try: sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn() _CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc() _NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand() _CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
_TE_LIB_CTYPES = _load_core_library() _TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
...@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right, bool return_max_logit) {
using namespace transformer_engine; using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
...@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8 // 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) { (cudnn_runtime_version != 91000) && !return_max_logit) {
if (cudnn_runtime_version >= 8900) { if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8; backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else { } else {
...@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset && !requires_64bit_ragged_offset &&
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) {
flag_m512 = true; flag_m512 = true;
} }
if ( if (
...@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale, size_t max_seqlen, bool is_training, bool return_max_logit,
float dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, int64_t window_size_right, NVTETensor workspace,
...@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout,
attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_rng_state, wkspace, stream, handle); input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con ...@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) { NVTETensor workspace, cudaStream_t stream) {
...@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903) #if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
wkspace, stream, handle); input_page_table_v, input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked(
} }
} }
// NVTE fused attention FWD with separate Q, K and V // NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, void nvte_fused_attn_fwd(
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd); NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
wkspace, stream, handle); input_page_table_v, input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
......
...@@ -20,12 +20,13 @@ namespace transformer_engine { ...@@ -20,12 +20,13 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
...@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd( void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......
...@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
qkv_tensor_type, qkv_tensor_type,
o_tensor_type, o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET}; cudnn_frontend::DataType_t::NOT_SET,
false};
namespace fe = cudnn_frontend; namespace fe = cudnn_frontend;
using graph_and_tensors = using graph_and_tensors =
...@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
qkv_tensor_type, qkv_tensor_type,
o_tensor_type, o_tensor_type,
do_tensor_type, do_tensor_type,
dqkv_tensor_type}; dqkv_tensor_type,
false};
namespace fe = cudnn_frontend; namespace fe = cudnn_frontend;
using graph_and_tensors = using graph_and_tensors =
......
...@@ -115,20 +115,21 @@ struct FADescriptor_v1 { ...@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t o_tensor_type; cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type; cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type; cudnn_frontend::DataType_t dqkv_tensor_type;
bool generate_max_sum_exp;
bool operator<(const FADescriptor_v1 &rhs) const { bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type) < o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type); rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
} }
}; };
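Adding `generate_max_sum_exp` to the `std::tie` comparison keeps cached cuDNN graphs keyed on the new flag; the same idea sketched in Python terms (a hedged illustration of the cache-key design choice, not the C++ code):

```python
# If a field were left out of the key, two configurations differing only in
# that field would collide and reuse the wrong cached graph.
graph_cache = {}

def get_graph(key, build):
    # `key` plays the role of std::tie(...) in FADescriptor_v1::operator<.
    if key not in graph_cache:
        graph_cache[key] = build()
    return graph_cache[key]

g_plain = get_graph(("bshd", 0.1, False), lambda: "graph without max/sum_exp")
g_max   = get_graph(("bshd", 0.1, True), lambda: "graph with max/sum_exp")
assert g_plain != g_max  # distinct entries because the flag is part of the key
```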
......
...@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8> ...@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) { StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>; using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output; result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
auto output_ptr = reinterpret_cast<uint16_t *>(&output); if constexpr (has_rs) {
asm volatile( \ auto output_ptr = reinterpret_cast<uint16_t *>(&output);
"{\n" \ asm volatile( \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ "{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"}" \ "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
: "=h"(output_ptr[0]), "}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1]) "=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1])); "r"(rbits[0]), "r"(rbits[1]));
#else } else {
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."); "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL }
return output; return output;
} }
......
...@@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); ...@@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters. /*! \brief Get fused attention backend based on input parameters.
* *
* \param[in] is_training Whether the model is in training mode. * \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q. * \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V. * \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type. * \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type. * \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type. * \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability. * \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q. * \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V. * \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q. * \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V. * \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim_qk The head dimension of Q, K. * \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V. * \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
*/ */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right, bool return_max_logit);
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
* *
...@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing, * \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1. * it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference. * \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability. * \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] qkv_layout QKV tensor's layout.
...@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd_qkvpacked( void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, size_t max_seqlen, bool is_training, bool return_max_logit,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input. /*! \brief Compute the backward of the dot product attention with packed QKV input.
* *
...@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con ...@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] max_seqlen_kv Max sequence length used for computing for KV. * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference. * \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability. * \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] qkv_layout QKV tensor's layout.
...@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
...@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference. * \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability. * \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout. * \param[in] qkv_layout QKV tensors' layout.
...@@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, void nvte_fused_attn_fwd(
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V. /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
* *
......
...@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s ...@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) { const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
uint16_t out_4x; if constexpr (has_rs) {
asm volatile( uint16_t out_4x;
"{\n" asm volatile(
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" "{\n"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
: "=h"(out_4x) "}"
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); : "=h"(out_4x)
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
#else return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt.rs PTX instructions are architecture-specific. "
uint16_t dummy = 0; "Try recompiling with sm_XXXa instead of sm_XXX.");
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); uint16_t dummy = 0;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
} }
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23, const float2 in23,
const uint32_t rbits) { const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY;
// NOTE: rbits unused for rn. if constexpr (has_fp4) {
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. // NOTE: rbits unused for rn.
asm volatile( uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
"{\n" asm volatile(
".reg.b8 f0; \n\t" "{\n"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" ".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); : "=r"(out_4x)
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
#else return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
uint16_t dummy = 0; "Try recompiling with sm_XXXa instead of sm_XXX.");
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); uint16_t dummy = 0;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
} }
template <bool kApplyStochasticRounding> template <bool kApplyStochasticRounding>
......
...@@ -15,10 +15,9 @@ ...@@ -15,10 +15,9 @@
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if CUDA_VERSION > 12080 #if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h> #include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080 #endif // FP4_TYPE_SUPPORTED
#include <cfloat> #include <cfloat>
#include "../common.h" #include "../common.h"
...@@ -30,7 +29,7 @@ ...@@ -30,7 +29,7 @@
namespace transformer_engine { namespace transformer_engine {
#if CUDA_VERSION > 12080 #if FP4_TYPE_SUPPORTED
namespace nvfp4_transpose { namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
...@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int ...@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
return rbits; return rbits;
} }
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) { const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0; uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
asm volatile( if constexpr (has_rs) {
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b16 v0_bf16; \n\t" ".reg.b64 v23; \n\t"
".reg.b16 v1_bf16; \n\t" ".reg.b16 v0_bf16; \n\t"
".reg.b16 v2_bf16; \n\t" ".reg.b16 v1_bf16; \n\t"
".reg.b16 v3_bf16; \n\t" ".reg.b16 v2_bf16; \n\t"
".reg.b32 v0; \n\t" ".reg.b16 v3_bf16; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" ".reg.b32 v3; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t" "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t" "cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t" "cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t" "cvt.f32.bf16 v2, v2_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order "mov.b64 {v3, v2}, v23; \n\t"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
: "=h"(out_4x) "}"
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits)); : "=h"(out_4x)
#else : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x); return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
} }
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale, const float2 scale,
const uint32_t rbits) { const uint32_t rbits) {
// NOTE: rbits unused for rn. constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL if constexpr (is_blackwell) {
asm volatile( // NOTE: rbits unused for rn.
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b16 v0_bf16; \n\t" ".reg.b64 v23; \n\t"
".reg.b16 v1_bf16; \n\t" ".reg.b16 v0_bf16; \n\t"
".reg.b16 v2_bf16; \n\t" ".reg.b16 v1_bf16; \n\t"
".reg.b16 v3_bf16; \n\t" ".reg.b16 v2_bf16; \n\t"
".reg.b32 v0; \n\t" ".reg.b16 v3_bf16; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
".reg.b8 f0; \n\t" ".reg.b32 v3; \n\t"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" ".reg.b8 f1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t" "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t" "cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t" "cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t" "cvt.f32.bf16 v2, v2_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" "mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale))); : "=r"(out_4x)
#else : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0]; return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
} }
...@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x ...@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0; uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
asm volatile( if constexpr (has_rs) {
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b32 v0; \n\t" ".reg.b64 v23; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
"mov.b64 {v0, v1} , %1; \n\t" ".reg.b32 v3; \n\t"
"mov.b64 {v2, v3} , %2; \n\t" "mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order "mov.b64 {v3, v2}, v23; \n\t"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
: "=h"(out_4x) "}"
: "l"(reinterpret_cast<const uint64_t &>(in01)), : "=h"(out_4x)
"l"(reinterpret_cast<const uint64_t &>(in23)), : "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits)); "l"(reinterpret_cast<const uint64_t &>(in23)),
#else "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x); return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
} }
...@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 ...@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
const float2 in23, const float2 in23,
const float2 scale, const float2 scale,
const uint32_t rbits) { const uint32_t rbits) {
// NOTE: rbits unused for rn. constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL if constexpr (is_blackwell) {
asm volatile( // NOTE: rbits unused for rn.
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b32 v0; \n\t" ".reg.b64 v23; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
".reg.b8 f0; \n\t" ".reg.b32 v3; \n\t"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"mov.b64 {v0, v1} , %1; \n\t" ".reg.b8 f1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t" "mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" "mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "l"(reinterpret_cast<const uint64_t &>(in01)), : "=r"(out_4x)
"l"(reinterpret_cast<const uint64_t &>(in23)), : "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(in23)),
#else "l"(reinterpret_cast<const uint64_t &>(scale)));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0]; return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
} }
...@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c ...@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
} }
} }
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &), template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE> typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM) __global__ void __launch_bounds__(THREADS_NUM)
...@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM) ...@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
} // namespace nvfp4_transpose } // namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080 #endif // FP4_TYPE_SUPPORTED
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &), template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization> bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) { const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080 #if FP4_TYPE_SUPPORTED
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to
...@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o ...@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
});); }););
#else #else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080 #endif // FP4_TYPE_SUPPORTED
} }
} // namespace transformer_engine } // namespace transformer_engine
......