Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -4,7 +4,7 @@
import random
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch import parallel_cross_entropy
from utils import dtype_tols
......
......@@ -8,6 +8,7 @@ import torch
import pytest
from typing import Dict, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
moe_permute as te_permute,
......@@ -16,14 +17,12 @@ from transformer_engine.pytorch import (
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
import copy
......@@ -1119,7 +1118,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
# TE tensor dtypes
_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
if is_bf16_compatible():
if te.is_bf16_available():
_te_dtypes.append(tex.DType.kBFloat16)
......@@ -1239,10 +1238,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
from typing import Iterable, Optional
from typing import Optional
import pytest
import torch
......@@ -11,27 +11,34 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch import (
Float8BlockQuantizer,
MXFP8Quantizer,
Float8Quantizer,
NVFP4Quantizer,
quantized_model_init,
Linear,
LayerNormLinear,
LayerNormMLP,
GroupedLinear,
)
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
_amax_and_scale_update,
fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True)
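For reference, the new top-level availability helpers used above replace the FP8GlobalStateManager class methods; a minimal sketch of gating a test on them, following the pattern this diff uses throughout:

import pytest
import transformer_engine.pytorch as te

fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_requires_fp8():
    ...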
# FP8 per tensor delayed scaling
......@@ -64,7 +71,7 @@ class TestFP8Recipe:
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
module = te.Linear(16, 16)
y = module(
torch.randn([16, 16], device="cuda"),
......@@ -120,7 +127,7 @@ class TestFP8Recipe:
# ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Perform forward, backward, and optimizer steps to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
with te.autocast(enabled=True, recipe=recipe):
x = torch.randn([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.randn_like(y))
......@@ -219,7 +226,7 @@ class TestFP8Recipe:
op.weight.fill_(w_history[-1])
# Forward and backward pass
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
y = op(x)
y.backward(dy)
......@@ -301,7 +308,7 @@ class TestFP8Recipe:
scaling_factor_compute_algo = None
if fused_update:
scaling_factor_compute_algo = (
lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
lambda amax, scale, fp8_max, recipe: te.quantization._default_sf_compute(
amax, scale, fp8_max, recipe.margin
)
)
......@@ -311,7 +318,7 @@ class TestFP8Recipe:
# Setup fp8_meta dictionary
def setup_fp8_meta():
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda"))
y.backward(torch.zeros_like(y))
......@@ -393,11 +400,11 @@ class TestFP8Recipe:
],
)
def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
with fp8_model_init(enabled=True, recipe=model_init_recipe):
with quantized_model_init(enabled=True, recipe=model_init_recipe):
linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
with te.autocast(enabled=True, recipe=DelayedScaling()):
with pytest.raises(RuntimeError) as excinfo:
_ = linear(x)
assert "Recipe mismatch for " in str(excinfo.value)
......@@ -436,7 +443,7 @@ class TestFP8Recipe:
# Run initial iterations with DelayedScaling
for _ in range(3):
x = torch.randn(batch_size, in_features, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
with te.autocast(enabled=True, recipe=initial_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
......@@ -453,7 +460,7 @@ class TestFP8Recipe:
if i == 0:
# Expect a warning on the first iteration with the new recipe
with pytest.warns(UserWarning, match="Recipe type changed"):
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
with te.autocast(enabled=True, recipe=target_recipe):
y = linear(x)
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
......@@ -461,7 +468,7 @@ class TestFP8Recipe:
# No warning expected on subsequent iterations
with warnings.catch_warnings():
warnings.simplefilter("error") # Raise error if unexpected warning occurs
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
with te.autocast(enabled=True, recipe=target_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
......@@ -485,7 +492,7 @@ class TestFP8Recipe:
batch_size = 32
recipe = DelayedScaling(amax_history_len=1024)
with fp8_model_init(recipe=recipe):
with quantized_model_init(recipe=recipe):
if module_class == GroupedLinear:
module = module_class(1, in_features, out_features).cuda()
else:
......@@ -493,10 +500,43 @@ class TestFP8Recipe:
x = torch.randn(batch_size, in_features, device="cuda")
recipe = DelayedScaling(amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
with te.autocast(enabled=True, recipe=recipe):
warn_msg = "Quantizer is being updated, this may affect model behavior"
with pytest.warns(UserWarning, match=warn_msg):
if module_class == GroupedLinear:
y = module(x, [batch_size])
else:
y = module(x)
@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# largest tile
(8192, 8192),
],
)
def test_fp4_dequantize(dtype, M, N):
q = NVFP4Quantizer()
a = torch.rand((M, N)).cuda().to(dtype=dtype)
starting_tensor = q(a)
dequantized_tensor = starting_tensor.dequantize()
new_tensor = q(dequantized_tensor)
torch.testing.assert_close(
new_tensor._rowwise_data,
starting_tensor._rowwise_data,
rtol=0,
atol=0,
)
new_dequantized_tensor = new_tensor.dequantize()
torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
......@@ -9,18 +9,16 @@ import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
fp8_model_init,
)
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
autocast,
quantized_model_init,
LayerNormLinear,
Linear,
GroupedLinear,
......@@ -28,26 +26,25 @@ from transformer_engine.pytorch import (
TransformerLayer,
RMSNorm,
LayerNorm,
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
MXFP8Tensor,
checkpoint,
QuantizedTensor,
is_bf16_available,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
# Record initial RNG state from script run.
seed = 1234
......@@ -88,9 +85,19 @@ model_configs = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
def nvfp4_vanilla():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
return nvfp4_recipe
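For illustration, a minimal sketch of exercising this vanilla-NVFP4 recipe on its own, assuming a device where te.is_nvfp4_available() returns True and using the autocast API from this diff:

import torch
import transformer_engine.pytorch as te

nvfp4_recipe = nvfp4_vanilla()
linear = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.autocast(enabled=True, recipe=nvfp4_recipe):
    y = linear(x)
y.sum().backward()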
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
......@@ -99,7 +106,7 @@ if fp8_available:
fp8_recipes.append(None)
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
all_boolean = [True, False]
......@@ -151,7 +158,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
......@@ -190,7 +197,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
p.main_grad = torch.zeros_like(p)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
......@@ -218,7 +225,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
loss = te_out.sum()
loss.backward()
......@@ -244,7 +251,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
......@@ -276,7 +283,7 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
......@@ -305,7 +312,7 @@ def _test_sanity_common(
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
if not microbatching:
te_out = block(te_inp)
else:
......@@ -386,6 +393,8 @@ def test_sanity_layernorm_linear(
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -414,6 +423,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......@@ -450,9 +461,11 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_linear = Linear(
config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
).cuda()
......@@ -460,7 +473,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
out = te_linear(inp_hidden_states)
loss = out.sum()
loss.backward()
......@@ -489,9 +502,11 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4():
pytest.skip("NVFP4 not supported for grouped linear")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
).cuda()
......@@ -507,7 +522,7 @@ def test_sanity_grouped_linear(
elif empty_split == "middle":
m_splits[num_gemms // 2] = 0
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=use_fp8, recipe=fp8_recipe):
out = te_grouped_linear(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
......@@ -545,6 +560,8 @@ def test_sanity_layernorm_mlp(
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -593,6 +610,8 @@ def test_sanity_gpt(
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -654,6 +673,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
pytest.skip(reason_for_no_fp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -708,6 +729,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
pytest.skip(reason_for_no_fp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -765,6 +788,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -801,6 +826,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -841,6 +868,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -881,6 +910,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -991,9 +1022,9 @@ def test_replace_raw_data_for_float8tensor():
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
"""Test fp8_model_init with preserve_high_precision_init_val=True"""
with fp8_model_init(preserve_high_precision_init_val=True):
def test_quantized_model_init_high_precision_init_val():
"""Test quantized_model_init with preserve_high_precision_init_val=True"""
with quantized_model_init(preserve_high_precision_init_val=True):
model = Linear(768, 768)
weight = model.weight
......@@ -1066,7 +1097,7 @@ def test_linear_frozen_weights_memory_default_recipe():
linear.weight.requires_grad = False
# Forward and backward pass with FP8
with fp8_autocast():
with autocast():
o = linear(x)
g_o = torch.randn_like(o)
......@@ -1120,7 +1151,7 @@ def test_inference_mode(
# Construct module
module = None
with torch.no_grad():
with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
if module_name == "Linear":
module = Linear(hidden_size, hidden_size)
elif module_name == "LayerNormLinear":
......@@ -1155,6 +1186,6 @@ def test_inference_mode(
kwargs = {}
if module_name == "GroupedLinear":
kwargs["m_splits"] = [sequence_length]
with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
with autocast(enabled=with_quantization, recipe=quantization_recipe):
y = module(x, **kwargs)
check_weights()
......@@ -7,19 +7,20 @@ from __future__ import annotations
import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import InferenceParams
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
......@@ -72,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat4E2M1:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.25
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
......@@ -94,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e4m3fn:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
return dict(rtol=0.25, atol=0.125) # epsilon = 0.125
raise ValueError(f"Unsupported dtype ({dtype})")
def quantization_tols(name: str) -> dict[str, float]:
"""Estimated numerical error for a quantization scheme"""
if name in (
"fp8",
"fp8_delayed_scaling",
"fp8_current_scaling",
"mxfp8",
"mxfp8_block_scaling",
):
return dtype_tols(tex.DType.kFloat8E4M3)
if name == "nvfp4":
return dtype_tols(tex.DType.kFloat4E2M1)
raise ValueError(f"Unsupported quantization scheme ({name})")
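A small usage sketch of the tolerance helpers above, assuming a reference/output pair whose mismatch stays within FP8 E4M3 rounding error:

import torch

tols = quantization_tols("fp8_current_scaling")  # ~ FP8 E4M3 rounding error
ref = torch.randn(16, 16)
out = ref * (1.0 + 0.01 * torch.randn(16, 16))  # stand-in for a dequantized result
torch.testing.assert_close(out, ref, **tols)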
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
......@@ -117,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
if name == "nvfp4":
return transformer_engine.common.recipe.NVFP4BlockScaling(
disable_rht=True,
disable_stochastic_rounding=True,
disable_2d_quantization=True,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
......@@ -137,6 +161,31 @@ def reset_rng_states() -> None:
torch.cuda.set_rng_state(cuda_rng_state)
def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
if not is_fp8:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
return
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = torch.sqrt((a - b).square().mean()).item()
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
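A minimal sketch of how the helper above is meant to be called; for quantized outputs it falls back to an RMSE-versus-dynamic-range check when the elementwise comparison fails (the inputs here are stand-ins, not real module outputs):

import torch

out = torch.randn(64, 64)
ref = out + 1e-2 * torch.randn(64, 64)  # noisy stand-in for a quantized result

# High-precision path: strict elementwise comparison only
compare_and_assert(out, out.clone(), "out", "ref", atol=1e-3, rtol=1e-3, rmse_tol=0.1, is_fp8=False)

# Quantized path: elementwise check fails, so RMSE (~0.01) vs rmse_tol * value range decides
compare_and_assert(out, ref, "fp8_out", "ref", atol=1e-3, rtol=1e-3, rmse_tol=0.1, is_fp8=True)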
class ModelConfig:
def __init__(
self,
......@@ -147,12 +196,15 @@ class ModelConfig:
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
softmax_type: str = "vanilla",
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -171,13 +223,16 @@ class ModelConfig:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.softmax_type = softmax_type
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -198,9 +253,7 @@ def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
......@@ -250,19 +303,21 @@ def get_available_attention_backends(
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
window_size=config.window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
context_parallel=config.context_parallel,
cp_comm_type=config.cp_comm_type,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
)
(
use_flash_attention,
......
......@@ -110,6 +110,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
endif()
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
if(NOT _MATHDX_LOC_MATCH)
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
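The block above resolves MATHDX_INCLUDE_DIR from `pip show nvidia-mathdx`; a rough Python equivalent for checking the path by hand, assuming nvidia-mathdx is installed for the same interpreter:

import re
import subprocess
import sys

out = subprocess.run(
    [sys.executable, "-m", "pip", "show", "nvidia-mathdx"],
    capture_output=True, text=True, check=True,
).stdout
location = re.search(r"^Location: (.+)$", out, re.MULTILINE).group(1)
print(f"{location}/nvidia/mathdx/include")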
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
......@@ -132,6 +154,7 @@ if(USE_CUDA)
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
......@@ -144,6 +167,7 @@ if(USE_CUDA)
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
......@@ -162,6 +186,7 @@ if(USE_CUDA)
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
......@@ -172,6 +197,9 @@ if(USE_CUDA)
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......@@ -206,6 +234,7 @@ else()
dropout/dropout.cu
activation/relu.cu
activation/swiglu.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/common.cpp
......@@ -224,6 +253,7 @@ else()
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
......
......@@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
......
......@@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......@@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
}
......@@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......@@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
}
......@@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}
......@@ -79,6 +79,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
#endif
_comm_created = true;
}
initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm);
}
void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm) {
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
......@@ -339,6 +348,11 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams;
initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm);
}
void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
......@@ -349,7 +363,9 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);
......@@ -765,6 +781,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams;
initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
}
void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
......@@ -772,28 +793,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
int buffer_chunk_bytes = buffer_bytes / _tp_size;
_num_ubuf_chunks = _tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
_num_ubuf_chunks = tp_size * 2 - 1;
buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1);
_num_ubuf_chunks = _tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
std::vector<size_t>{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
std::vector<size_t>{buffer_shape[0] / _tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
......@@ -818,7 +839,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
static cudaStream_t recv_stream;
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
for (int i = 0; i < std::min(_ub_stream_nums, _tp_size); i++) {
if (send_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
}
......@@ -842,6 +863,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
}
}
void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source,
bool local_chunk, bool rowwise) {
// Check element size
const size_t element_size = source.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t source_size = source.numel();
const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
NVTE_CHECK(_ubufs[_tp_id].numel() == source_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
NVTE_CHECK(_ubuf.numel() == source_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size,
cudaMemcpyDeviceToDevice, stream));
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
......@@ -982,6 +1035,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Check B copy sizing
if (B_copy.numel() > 0) {
NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
_ubuf.numel(), " elements but got ", B_copy.numel());
NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
"Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
"-bit data type but got ", B_copy.element_size() * 8, "-bit");
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
......@@ -1057,12 +1119,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
} else {
......@@ -1117,16 +1173,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
}
// Copy all-gathered B from communication buffer into auxiliary output
if (B_copy.numel() > 0) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
......
......@@ -679,9 +679,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
// Check for NVLINK support before attempting IPC operations
if (comm->nvsize > 1) {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
cudaDeviceProp deviceProp;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
bool peer_access_available = false;
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
int can_access_peer;
cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
if (peer_result == cudaSuccess && can_access_peer) {
peer_access_available = true;
break;
}
}
}
if (!peer_access_available) {
free(tmp);
NVTE_ERROR(
"No peer-to-peer access available between GPUs. This platform does not support the "
"GPU-to-GPU "
"communication required for multi-GPU userbuffers. Consider using single-GPU mode.");
return 1;
}
}
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],
cudaIpcMemLazyEnablePeerAccess));
}
}
......@@ -702,4 +729,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->mem_ptr[hndl] = *gpubuff;
printf("***** Returning *****\n");
return comm->free_region++;
}
......@@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
#if CUDA_VERSION >= 12080
case DType::kFloat4E2M1:
return CUDA_R_4F_E2M1;
#endif
default:
NVTE_ERROR("Invalid type");
}
......@@ -165,7 +169,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_num_bits) {
const uint32_t offset_elems, const size_t type_num_bits,
const CUtensorMapSwizzle swizzle) {
cuda_driver::ensure_context_exists();
// Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
......@@ -174,6 +180,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
}();
// rank is the number of dimensions of the array
constexpr uint32_t rank = 2;
// Dimension for the packed data types must reflect the number of individual U# values.
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
......@@ -212,7 +220,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
// Swizzling can be used to avoid shared memory bank conflicts.
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
swizzle,
// L2 Promotion can be used to widen the effect of a cache-policy to a wider
// set of L2 cache lines.
......
......@@ -54,8 +54,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
......@@ -114,6 +120,7 @@ struct Tensor {
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax;
SimpleTensor columnwise_amax;
SimpleTensor scale;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
......@@ -125,6 +132,7 @@ struct Tensor {
: data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32),
columnwise_amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
......@@ -135,6 +143,7 @@ struct Tensor {
data.clear();
columnwise_data.clear();
amax.clear();
columnwise_amax.clear();
scale.clear();
scale_inv.clear();
columnwise_scale_inv.clear();
......@@ -180,6 +189,7 @@ struct Tensor {
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING:
case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
......@@ -195,7 +205,6 @@ struct Tensor {
}
break;
case NVTE_MXFP8_1D_SCALING:
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
......@@ -267,12 +276,18 @@ struct QuantizationConfig {
NVTETensor noop_tensor = nullptr;
Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
Float8BlockScaleTensorFormat::GEMM_READY;
NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
sizeof(float), // amax_epsilon
sizeof(NVTETensor), // noop_tensor
sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format
sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding
};
};
......@@ -305,6 +320,8 @@ using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
using e8m0_t = uint8_t;
......@@ -342,12 +359,14 @@ struct TypeExtrema;
template <>
struct TypeExtrema<fp4e2m1> {
static constexpr float max = 6.0f;
static constexpr float max_inverse = 1.0 / max;
};
#endif
template <>
struct TypeExtrema<fp8e4m3> {
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;
};
template <>
......@@ -358,6 +377,7 @@ struct TypeExtrema<int8> {
template <>
struct TypeExtrema<fp8e5m2> {
static constexpr float max = 57344.0f;
static constexpr float max_inverse = 1.0 / max;
};
template <>
......@@ -602,6 +622,18 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
// Add a pack_size argument to select the packed type for FP4
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat4E2M1: { \
using type = __nv_fp4x2_storage_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -812,10 +844,11 @@ void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_num_bits);
void create_2D_tensor_map(
CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
#endif
bool is_supported_by_CC_100();
......
......@@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// sm90: fwd d<=256, bwd d=128 only
// sm100: fwd d<=128, bwd d<=128
((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
(sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
(sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
......@@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset &&
!requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
......@@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset) {
!requires_64bit_ragged_offset &&
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
flag_m512 = true;
}
if (
......@@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// check 64-bit ragged offset support
(supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) &&
// softmax type
// pre-9.13.1: vanilla
// 9.13.1+: vanilla, off-by-one, learnable
(cudnn_runtime_version >= 91301 ||
(cudnn_runtime_version < 91301 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......@@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
// NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
const Tensor *input_QKV = convertNVTETensorCheck(QKV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
......@@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
Tensor *input_output_dP = convertNVTETensorCheck(dP);
Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_QKV->data.shape.size();
......@@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
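The backward paths now walk Aux_CTX_Tensors with a running index instead of fixed slots, so the pack size depends on bias_type and softmax_type; the kvpacked and separate-QKV backward paths below, and the forward paths that populate the pack, follow the same ordering. A minimal sketch of the expected layout and count, with an illustrative helper name:

#include <cstddef>

// Layout produced/consumed by the arbitrary-seqlen paths:
//   [0] softmax stats, [1] RNG state, then, in order and only when present:
//   bias (bias_type neither NO_BIAS nor ALIBI) and softmax offset (softmax_type != VANILLA_SOFTMAX).
std::size_t expected_aux_count(bool has_bias, bool has_softmax_offset) {
  std::size_t n = 2;               // stats + RNG state are always present
  if (has_bias) ++n;
  if (has_softmax_offset) ++n;
  return n;                        // 2, 3, or 4 tensors
}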
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked(
const Tensor *input_Q = convertNVTETensorCheck(Q);
const Tensor *input_KV = convertNVTETensorCheck(KV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor *output_dQ = convertNVTETensorCheck(dQ);
Tensor *output_dKV = convertNVTETensorCheck(dKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
......@@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ,
output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked(
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_K = convertNVTETensorCheck(K);
const Tensor *input_V = convertNVTETensorCheck(V);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *output_dK = convertNVTETensorCheck(dK);
Tensor *output_dV = convertNVTETensorCheck(dV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_Q->data.shape.size();
......@@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
......
......@@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats,
void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_causal = true;
is_bottom_right = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (is_training && dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
......@@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
FADescriptor_v1 descriptor{b,
h,
......@@ -122,11 +124,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout,
bias_type,
mask_type,
softmax_type,
window_size_left,
window_size_right,
true,
tensorType,
tensorType};
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET};
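FADescriptor_v1 is the key under which the built cuDNN frontend graphs are cached (see the cache.insert/get_graph calls below), so softmax_type has to become part of it: a request for sink/learnable softmax must not reuse a graph that was built for vanilla softmax. A generic sketch of that cache-keying idea, with illustrative names unrelated to the real descriptor fields:

#include <map>
#include <tuple>

struct GraphKey {
  int mask_type;
  int softmax_type;  // omitting this field would let different softmax variants collide on one graph
  bool operator<(const GraphKey &o) const {
    return std::tie(mask_type, softmax_type) < std::tie(o.mask_type, o.softmax_type);
  }
};
// cache of built graphs, keyed on every parameter that shapes the graph
std::map<GraphKey, int /* built graph handle */> graph_cache;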
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -138,6 +143,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_k
......@@ -168,7 +174,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale, softmax_offset;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
......@@ -302,6 +308,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
if (is_softmax_offset) {
softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_options.set_sink_token(softmax_offset);
}
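A numerical sketch of what the per-head softmax_offset registered via set_sink_token is assumed to do: it behaves as one extra, value-less logit per head, so it only enlarges the softmax denominator (off-by-one softmax corresponds to a fixed offset of 0, the learnable variant to a trained per-head parameter). This reflects the variants named in the backend checks above, not code from this change:

#include <cmath>
#include <cstddef>
#include <vector>

// Attention weights for one query row of one head, with an optional sink logit.
// logits: scaled q.k^T scores for that row; offset: the per-head softmax offset.
// (No max-subtraction here; a real kernel would stabilize the exponentials.)
std::vector<float> softmax_with_offset(const std::vector<float> &logits, float offset) {
  float denom = std::exp(offset);                 // the sink contributes to the denominator only
  for (float x : logits) denom += std::exp(x);
  std::vector<float> p(logits.size());
  for (std::size_t j = 0; j < logits.size(); ++j) p[j] = std::exp(logits[j]) / denom;
  return p;                                       // the row sums to < 1; the remainder "attends to" the sink
}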
auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
std::vector<int64_t> o_stride(4);
......@@ -338,6 +353,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto softmax_offset_tuple =
is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
......@@ -358,17 +375,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple,
offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv,
page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to its type.
......@@ -473,6 +491,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
if (is_softmax_offset) {
variant_pack[softmax_offset] = devPtrSoftmaxOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
......@@ -483,14 +506,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -506,6 +529,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_causal = true;
is_bottom_right = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
......@@ -558,11 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout,
bias_type,
mask_type,
softmax_type,
window_size_left,
window_size_right,
deterministic,
tensorType,
tensorType};
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -579,6 +606,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // d_softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
......@@ -608,7 +637,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, softmax_offset, d_softmax_offset,
seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
......@@ -771,6 +801,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
if (is_softmax_offset) {
softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_backward_options.set_sink_token(softmax_offset);
d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("d_softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_backward_options.set_dsink_token(d_softmax_offset);
}
auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
......@@ -796,6 +841,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>> // dV
key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto softmax_offset_tuple = is_softmax_offset
? std::make_tuple(softmax_offset, d_softmax_offset)
: std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
......@@ -814,17 +862,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
softmax_offset_tuple, padding_tuple, offset_qo_tuple,
offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to its type.
......@@ -938,6 +986,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
if (is_softmax_offset) {
variant_pack[softmax_offset] = devPtrSoftmaxOffset;
variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
......@@ -949,8 +1002,9 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -977,6 +1031,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -990,11 +1048,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
max_tokens = get_max_tokens(num_tokens);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
......@@ -1002,41 +1059,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1050,11 +1105,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr,
nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1074,9 +1129,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1122,6 +1178,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
......@@ -1135,11 +1197,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK,
devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1161,12 +1223,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1192,6 +1254,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -1216,11 +1282,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
......@@ -1228,41 +1293,39 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1277,11 +1340,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1302,10 +1365,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1359,6 +1423,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1374,9 +1444,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......@@ -1401,12 +1472,13 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1425,6 +1497,10 @@ void fused_attn_arbitrary_seqlen_fwd(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1446,11 +1522,10 @@ void fused_attn_arbitrary_seqlen_fwd(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
......@@ -1458,41 +1533,39 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1507,11 +1580,11 @@ void fused_attn_arbitrary_seqlen_fwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
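  // When the caller has not yet provided a workspace buffer, this call only reports the
  // required workspace size; the buffer is allocated afterwards and the function is invoked
  // again to launch the fused kernel.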
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1532,13 +1605,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
......@@ -1577,6 +1651,12 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
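  // Softmax-offset tensors are threaded through only for non-vanilla softmax variants:
  // the forward pass consumed an offset tensor and the backward pass emits its gradient.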
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1592,9 +1672,10 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......
......@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1(
void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK,
void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
......@@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1(
auto bias_h = h;
  NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
  NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
NVTE_CHECK(is_current_scaling || is_delayed_scaling,
"FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
"kFloat8E5M2!");
try {
FADescriptor_v1 descriptor{b,
......@@ -1695,11 +1703,14 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
fwd_tensor_type,
fwd_tensor_type};
qkv_tensor_type,
o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -1738,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1(
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
mha_graph->set_io_data_type(qkv_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
......@@ -1786,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1(
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
if (is_delayed_scaling) {
scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
}
if (is_current_scaling) {
scale_o = mha_graph->tensor(1.0f);
}
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
......@@ -1838,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1(
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type);
amax_o->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_s->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
......@@ -1915,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1(
{descale_v, devPtrDescaleV},
{descale_s, devPtrDescaleS},
{scale_s, devPtrScaleS},
{scale_o, devPtrScaleO},
{attn_scale, &scaling_factor},
{O, devPtrO},
{amax_s, devPtrAmaxS},
{amax_o, devPtrAmaxO},
{Stats, devPtrM}};
if (is_delayed_scaling) {
variant_pack[scale_o] = devPtrScaleO;
}
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
} */
......@@ -1962,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1(
void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size,
void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -1977,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1(
auto bias_h = h;
  NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
  NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
NVTE_CHECK(is_current_scaling || is_delayed_scaling,
"FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
"kFloat8E5M2!");
bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
try {
FADescriptor_v1 descriptor{b,
......@@ -2000,11 +2031,14 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
false,
fwd_tensor_type,
bwd_tensor_type};
qkv_tensor_type,
o_tensor_type,
do_tensor_type,
dqkv_tensor_type};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -2057,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1(
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
mha_graph->set_io_data_type(qkv_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
......@@ -2097,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1(
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
.set_stride(o_stride)
.set_data_type(o_tensor_type));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
......@@ -2123,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1(
        descale_k = mha_graph->tensor_like(descale_q, "Descale_K");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
if (is_O_in_F16) {
descale_o = mha_graph->tensor(1.0f);
} else {
descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
}
descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
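        // dQ/dK/dV scales are real device tensors only under delayed scaling; with current
        // scaling the gradients are written in high precision and the scales collapse to 1.0.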
if (is_delayed_scaling) {
scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
}
if (is_current_scaling) {
scale_dQ = mha_graph->tensor(1.0f);
scale_dK = mha_graph->tensor(1.0f);
scale_dV = mha_graph->tensor(1.0f);
}
fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
......@@ -2212,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
dO->set_data_type(bwd_tensor_type);
dQ->set_data_type(bwd_tensor_type);
dK->set_data_type(bwd_tensor_type);
dV->set_data_type(bwd_tensor_type);
dO->set_data_type(do_tensor_type);
dQ->set_data_type(dqkv_tensor_type);
dK->set_data_type(dqkv_tensor_type);
dV->set_data_type(dqkv_tensor_type);
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
......@@ -2296,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1(
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
{descale_o, devPtrDescaleO},
{descale_dO, devPtrDescaledO},
{descale_s, devPtrDescaleS},
{descale_dP, devPtrDescaledP},
{scale_s, devPtrScaleS},
{scale_dQ, devPtrScaledQ},
{scale_dK, devPtrScaledK},
{scale_dV, devPtrScaledV},
{scale_dP, devPtrScaledP},
{dQ, devPtrdQ},
{dK, devPtrdK},
......@@ -2314,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1(
{amax_dP, devPtrAmaxdP},
};
if (is_delayed_scaling) {
variant_pack[scale_dQ] = devPtrScaledQ;
variant_pack[scale_dK] = devPtrScaledK;
variant_pack[scale_dV] = devPtrScaledV;
}
if (!is_O_in_F16) {
variant_pack[descale_o] = devPtrDescaleO;
}
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
......@@ -2364,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const DType O_type = output_O->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
......@@ -2430,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
......@@ -2465,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked(
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQKV->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
......@@ -2482,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked(
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
const DType O_type = input_O->data.dtype;
void* devPtrDescaleO = nullptr;
if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
devPtrDescaleO = input_O->scale_inv.dptr;
}
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
......@@ -2525,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked(
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
......@@ -2563,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType O_type = output_O->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
......@@ -2631,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
......@@ -2669,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked(
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
......@@ -2686,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked(
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
const DType O_type = input_O->data.dtype;
void* devPtrDescaleO = nullptr;
if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
devPtrDescaleO = input_O->scale_inv.dptr;
}
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
......@@ -2731,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked(
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
......@@ -2820,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
const DType O_type = output_O->data.dtype;
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
......@@ -2829,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
......@@ -2876,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
const DType O_type = input_O->data.dtype;
void* devPtrDescaleO = nullptr;
if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
devPtrDescaleO = input_O->scale_inv.dptr;
}
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
......@@ -2909,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
size_t workspace_size = 0;
......@@ -2922,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
......
......@@ -107,23 +107,28 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
cudnn_frontend::DataType_t fwd_tensor_type;
cudnn_frontend::DataType_t bwd_tensor_type;
cudnn_frontend::DataType_t qkv_tensor_type;
cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
rhs.bwd_tensor_type);
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type);
}
};
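// FADescriptor_v1 is the key used to cache built cuDNN graphs; a cached plan is reused only
// when every field above, including the new softmax_type and the four per-tensor data types,
// compares equal under operator<.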
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "./config.h"
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cstring>
#include "../util/logging.h"
NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; }
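// Attribute getters follow a size-query convention: a call with buf == nullptr only reports
// the attribute size through size_written; a second call with a sufficiently large buffer
// returns the value itself.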
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
switch (attr) {
case kNVTEMatmulConfigBiasTensor:
std::memcpy(buf, &config_.bias_tensor, attr_size);
break;
case kNVTEMatmulConfigDBiasTensor:
std::memcpy(buf, &config_.dbias_tensor, attr_size);
break;
case kNVTEMatmulConfigWithGELUEpilogue:
std::memcpy(buf, &config_.with_gelu_epilogue, attr_size);
break;
case kNVTEMatmulConfigWithDGELUEpilogue:
std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size);
break;
case kNVTEMatmulConfigEpilogueAuxTensor:
std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
break;
case kNVTEMatmulConfigUseSplitAccumulator:
std::memcpy(buf, &config_.use_split_accumulator, attr_size);
break;
case kNVTEMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
switch (attr) {
case kNVTEMatmulConfigBiasTensor:
std::memcpy(&config_.bias_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigDBiasTensor:
std::memcpy(&config_.dbias_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigWithGELUEpilogue:
std::memcpy(&config_.with_gelu_epilogue, buf, attr_size);
break;
case kNVTEMatmulConfigWithDGELUEpilogue:
std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size);
break;
case kNVTEMatmulConfigEpilogueAuxTensor:
std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigUseSplitAccumulator:
std::memcpy(&config_.use_split_accumulator, buf, attr_size);
break;
case kNVTEMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
}
}
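For reference, a minimal caller-side sketch of this config API follows. It is not part of the change above; it assumes the declarations are exposed through transformer_engine/gemm.h (as the include list of this file suggests) and that sm_count and use_split_accumulator are stored as an int and a bool respectively.

#include <cassert>
#include <cstddef>

#include <transformer_engine/gemm.h>

// Minimal sketch: configure SM count and split accumulation for a matmul.
// Assumptions: sm_count is stored as an int, use_split_accumulator as a bool.
void configure_matmul_example() {
  NVTEMatmulConfig cfg = nvte_create_matmul_config();

  // Size query: with buf == nullptr only the attribute size is reported.
  size_t sm_count_size = 0;
  nvte_get_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, nullptr, 0, &sm_count_size);
  assert(sm_count_size == sizeof(int));  // assumption: sm_count is an int

  // Set the SM count and request split accumulation.
  int sm_count = 108;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count, sizeof(sm_count));
  bool use_split_accumulator = true;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigUseSplitAccumulator,
                                   &use_split_accumulator, sizeof(use_split_accumulator));

  // ... hand cfg to a matmul entry point that accepts an NVTEMatmulConfig ...

  nvte_destroy_matmul_config(cfg);
}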