Unverified Commit 5e4e0b2c authored by Charlene Yang, committed by GitHub

[PyTorch] Add sink attention support from cuDNN (#2148)



* first draft; debug plan failure
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug uid error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak params
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add grad in output
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix prints in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* address review comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused grad; add softmax_type; add sink to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding mask; add swa tests; remove requires_grad for off-by-one
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix non-determinism and shapes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add GQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add CP A2A; dq/dk mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; need cleaner solution
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; pending cudnn kernel change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix world size in unit test; avoid thd format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 context
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak CP logging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow no_mask/padding for SWA(left,0)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "allow no_mask/padding for SWA(left,0)"

This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add softmax_type to Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add cuDNN version control
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prettify tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip 9.13 for MLA, non 192/128
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename compare_with_error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* small cleanups and improvements
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix minor CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force sink/dsink to be float32
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch FE to GH FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return to GH TE main FE commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.14.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up before CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* bump up cudnn version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add backend selection guard for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for softmax type enums in C
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
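For readers unfamiliar with the feature being added: "sink attention" (the softmax_type="off-by-one" and "learnable" variants exercised by this PR) gives each softmax row one extra virtual logit, so the attention weights over real keys can sum to less than 1 and a head can effectively attend to nothing. A minimal pure-Python sketch of the idea (the function name `sink_softmax` and the sample values are illustrative, not TE or cuDNN API):

```python
import math

def sink_softmax(logits, sink_logit):
    # Softmax over the row extended with one extra "sink" logit; the sink
    # column is then discarded, so the returned weights sum to < 1.
    m = max(max(logits), sink_logit)  # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(sink_logit - m)
    return [e / denom for e in exps]

row = [0.2, 1.5, -0.3, 0.8]
off_by_one = sink_softmax(row, 0.0)  # "off-by-one": sink logit fixed at 0
learnable = sink_softmax(row, 0.7)   # "learnable": sink logit is a trained scalar
```

With a very negative sink logit the extra column vanishes and vanilla softmax is recovered, which is why the vanilla path can be treated as a special case.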
parent 57b4d7bc
Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 → Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8
@@ -17,88 +17,18 @@ from test_attention_with_cp import model_configs_flash_attn, model_configs_fused
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def generate_input_shapes(
qkv_format: str,
config: ModelConfig,
world_size: int,
kernel_backend: str,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
@@ -191,34 +121,158 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format=} is not supported!"
return (
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
)
def get_tols(config, dtype):
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
atol = 2.5e-2
rtol = 2.5e-2
else:
atol = 3.5e-2
rtol = 3.5e-2
rmse_tol = 0.01
elif dtype == "fp16":
atol = 5e-3
rtol = 5e-3
rmse_tol = 0.01
elif dtype == "fp8":
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
else:
assert False, f"{dtype=} is not supported!"
return atol, rtol, rmse_tol
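The rmse_tol returned here feeds compare_and_assert (imported from utils), which for FP8 falls back from an elementwise allclose check to a relative-RMSE criterion, as the previous in-file helper did. A minimal sketch of that criterion (the name `rmse_within_tol` is illustrative, not the utils API):

```python
import math

def rmse_within_tol(a, b, rmse_tol):
    # Relative-RMSE check: the root-mean-square error between the two
    # sequences must stay below rmse_tol times their combined value range.
    rmse = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))
    value_range = max(max(a), max(b)) - min(min(a), min(b))
    return rmse < rmse_tol * value_range

ok = rmse_within_tol([0.0, 1.0, 2.0], [0.0, 1.02, 1.98], rmse_tol=0.1)
```

Scaling the tolerance by the value range makes the check insensitive to the absolute magnitude of the tensors, which is what makes it usable across FP8 quantization scales.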
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# set up environment variables and config
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type=} is not supported!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# set up communication group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
f"{cp_comm_type=} requires world_size % 2 == 0 as it assumes the a2a level has cp_size"
" = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
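The sub-rank construction above can be illustrated standalone: for an assumed world_size of 8, the inner a2a level pairs consecutive ranks (cp_size 2) and the outer p2p level groups ranks of the same parity.

```python
world_size = 8  # example value; the test reads this from the launcher env

# Inner a2a groups: consecutive rank pairs, mirroring range(i * 2, (i + 1) * 2).
a2a_groups = [list(range(i * 2, (i + 1) * 2)) for i in range(world_size // 2)]

# Outer p2p groups: ranks strided by 2, mirroring range(i, world_size, 2).
p2p_groups = [list(range(i, world_size, 2)) for i in range(2)]
```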
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate attention module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
# generate attention inputs
(
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
for x in [q, k, v]:
x.requires_grad = True
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
if fp8_mha:
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8": if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else: else:
fp8_context = nullcontext() fp8_context = nullcontext()
with fp8_context: with fp8_context:
out = core_attn( out = core_attn(
q, q,
@@ -236,8 +290,30 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
############ run with CP ############
logging.info(f"[Rank {rank}] Run with context parallelism")
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if config.softmax_type != "vanilla":
core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# set up inputs
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
@@ -267,8 +343,6 @@ def run_dpa_with_cp(
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
@@ -276,19 +350,8 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
# run attention
with fp8_context:
out_ = core_attn(
q_,
@@ -306,18 +369,23 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
# get outputs
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
if x is not None:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
@@ -373,56 +441,70 @@ def run_dpa_with_cp(
).item()
== 0
)
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
else:
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
# destroy distributed process group
dist.destroy_process_group()
@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
import logging
import os
import sys
import pathlib
@@ -50,27 +49,35 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
compare_and_assert,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
seed = 1234
# Reset RNG states
reset_rng_states()
# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
model_configs_base = {
# test: ModelConfig(b, sq, hq, dqk)
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
@@ -86,12 +93,6 @@ model_configs_base = {
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -125,12 +126,12 @@ def test_dot_product_attention(
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
# Get backends
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
@@ -141,7 +142,6 @@ def test_dot_product_attention(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
@@ -227,6 +227,7 @@ def test_dot_product_attention(
is_training,
)
# Compare results
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
@@ -259,23 +260,102 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
"softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
"softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
"softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
"softmax_2_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
),
"softmax_2_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
),
"softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
"softmax_3_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
),
"softmax_3_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
),
"softmax_4_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
),
"softmax_4_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="off-by-one",
),
"softmax_4_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="learnable",
),
"softmax_5_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
),
"softmax_5_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="off-by-one",
),
"softmax_5_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="learnable",
),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(
dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
)
model_configs_mla = {
# test: ModelConfig(b, sq, hq, dqk)
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
"mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
),
"mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
),
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
}
@@ -289,7 +369,7 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
# test: ModelConfig(b, sq, hq, dqk)
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
@@ -344,18 +424,16 @@ def test_dpa_mask(dtype, model_configs, model):
 model_configs_bias = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
     "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
     "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
     "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
-    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),  # skipped
+    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
-    "bias_1_5": ModelConfig(
-        2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
-    ),  # skipped
+    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
     "bias_2_0": ModelConfig(
         4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
-    ),  # skipped
+    ),
     "bias_2_1": ModelConfig(
         2,
         128,
@@ -364,10 +442,10 @@ model_configs_bias = {
         max_seqlen_kv=256,
         attn_mask_type="padding",
         attn_bias_type="post_scale_bias",
-    ),  # skipped
+    ),
     "bias_2_2": ModelConfig(
         4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
-    ),  # skipped
+    ),
     "bias_2_3": ModelConfig(
         2,
         2048,
@@ -376,13 +454,11 @@ model_configs_bias = {
         max_seqlen_kv=4096,
         attn_mask_type="padding",
         attn_bias_type="post_scale_bias",
-    ),  # skipped
+    ),
-    "bias_2_4": ModelConfig(
-        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
-    ),  # skipped
+    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
     "bias_2_5": ModelConfig(
         2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
-    ),  # skipped
+    ),
     "bias_3_0": ModelConfig(
         4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),
@@ -400,14 +476,14 @@ model_configs_bias = {
         max_seqlen_kv=4096,
         attn_mask_type="causal",
         attn_bias_type="post_scale_bias",
-    ),  # skipped
+    ),
     "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
     "bias_3_5": ModelConfig(
         2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
-    ),  # skipped
+    ),
     "bias_4_0": ModelConfig(
         4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
-    ),  # skipped
+    ),
     "bias_4_1": ModelConfig(
         2,
         128,
@@ -416,10 +492,10 @@ model_configs_bias = {
         max_seqlen_kv=256,
         attn_mask_type="padding_causal",
         attn_bias_type="post_scale_bias",
-    ),  # skipped
+    ),
     "bias_4_2": ModelConfig(
         4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
-    ),  # skipped
+    ),
     "bias_4_3": ModelConfig(
         2,
         2048,
@@ -428,10 +504,10 @@ model_configs_bias = {
         max_seqlen_kv=4096,
         attn_mask_type="padding_causal",
         attn_bias_type="post_scale_bias",
-    ),  # skipped
+    ),
     "bias_4_4": ModelConfig(
         4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
-    ),  # skipped
+    ),
     "bias_4_5": ModelConfig(
         2,
         2048,
@@ -440,7 +516,7 @@ model_configs_bias = {
         max_seqlen_kv=4096,
         attn_mask_type="padding_causal",
         attn_bias_type="alibi",
-    ),  # skipped
+    ),
 }
@@ -454,7 +530,7 @@ def test_dpa_bias(dtype, model_configs, model):
 model_configs_bias_shapes = {
-    # test: b, h, hg, d, sq, skv, p,
+    # test: ModelConfig(b, sq, hq, dqk)
     "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
     "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
     "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
@@ -492,7 +568,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 model_configs_swa = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "swa_1_1": ModelConfig(2, 2048, 16, 64),
     "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
     "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
@@ -532,7 +608,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 model_configs_alibi_slopes = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
+    # test: ModelConfig(b, sq, hq, dqk)
     "alibi_1_0": ModelConfig(
         2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
     ),
@@ -586,7 +662,7 @@ qkv_layouts = [
 model_configs_layout = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "layout_0_0": ModelConfig(2, 128, 16, 64),
     "layout_0_1": ModelConfig(
         2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
@@ -634,7 +710,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
     "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
     "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
@@ -726,7 +802,6 @@ def _run_dot_product_attention(
     is_training: bool,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
     # Set RNG and environment varables
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -989,9 +1064,12 @@ def _run_dot_product_attention(
         tp_group=None,
         layer_number=1,
         attention_type=config.attn_type,
+        softmax_type=config.softmax_type,
     ).to(dtype=dtype, device="cuda")
     if not is_training:
         block = block.eval()
+    if is_training and config.softmax_type != "vanilla":
+        block.softmax_offset.requires_grad = True

     # Run a forward and backward pass
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -1026,12 +1104,14 @@ def _run_dot_product_attention(
         )
     if is_training:
         out.backward(d_out)
+    d_softmax_offset = None
+    if is_training and config.softmax_type != "vanilla":
+        d_softmax_offset = block.softmax_offset.grad

     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         if is_training:
-            return out, (q.grad, k.grad, v.grad)
+            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None)
+            return out, (None, None, None, d_softmax_offset)
     if backend == "FusedAttention":
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1060,18 +1140,18 @@ def _run_dot_product_attention(
                 [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
             )
             if is_training:
-                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
+                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
             else:
-                return out_orig, (None, None, None)
+                return out_orig, (None, None, None, d_softmax_offset)
         else:
             if is_training:
-                return out, (q.grad, k.grad, v.grad)
+                return out, (q.grad, k.grad, v.grad, d_softmax_offset)
             else:
-                return out, (None, None, None)
+                return out, (None, None, None, d_softmax_offset)
 model_configs_te_layer = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
     "te_1_1": ModelConfig(
         4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
@@ -1436,6 +1516,7 @@ def _run_transformer_layer(
 model_configs_fp8_extra_state = {
+    # test: ModelConfig(b, sq, hq, dqk)
     "large": ModelConfig(2, 128, 4, 128, num_layers=1),
 }
@@ -1445,7 +1526,8 @@ model_configs_fp8_extra_state = {
 @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_sanity_attention_extra_state(model, dtype):
+def test_dpa_fp8_extra_state(model, dtype):
+    """Test DotProductAttention module in FP8 with checkpointing"""
     config = model_configs_fp8_extra_state[model]
@@ -1459,9 +1541,9 @@ def test_sanity_attention_extra_state(model, dtype):
     if not fused_attn_supported and not flash_attn_supported:
         pytest.skip("No attention backend available.")

-    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
+    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
-    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
+    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
-    outputs_checkpoint_v1_6 = _run_attention_extra_state(
+    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
         dtype, config, mimic_v1_6=True, checkpoint=True
     )
@@ -1483,7 +1565,8 @@ def test_sanity_attention_extra_state(model, dtype):
 )
-def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+    """Run DotProductAttention module in FP8 with checkpointing"""
     steps = 10
     path = "checkpoint.pt"
     fp8_enabled = True
@@ -1580,7 +1663,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
 model_configs_fp8_vs_f16 = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "fp8_9": ModelConfig(2, 2048, 16, 128),
     "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
     "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
@@ -1600,33 +1683,6 @@ qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
 qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
-def _rmse(a, b):
-    return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
-
-
-def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
-    logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
-    logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
-    try:
-        if a.dtype != b.dtype:
-            a = a.to(b.dtype)
-        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
-    except Exception as e:
-        logging.debug(e)
-        rmse = _rmse(a, b)
-        logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
-        rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
-        assert rmse < rmse_tol * rmse_range, (
-            name_a
-            + " vs "
-            + name_b
-            + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-                rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
-            )
-        )
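The removed `_rmse`/`_error` helpers are replaced by a shared `compare_and_assert` utility below; a sketch of the comparison logic they implemented (the shared helper's exact signature is an assumption here): elementwise `assert_close` first, with an RMSE-over-dynamic-range fallback for noisy FP8-vs-F16 comparisons.

```python
import logging
import math

import torch


def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, use_rmse):
    """Compare two tensors elementwise; on failure, optionally fall back to an
    RMSE check scaled by the tensors' combined dynamic range."""
    if a.dtype != b.dtype:
        a = a.to(b.dtype)
    try:
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
    except AssertionError as e:
        if not use_rmse:
            raise
        logging.debug(e)
        # Root-mean-square error, same quantity as sqrt(sum((a-b)^2)/numel)
        rmse = math.sqrt(torch.mean((a - b) ** 2).item())
        rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
        assert rmse < rmse_tol * rmse_range, (
            f"{name_a} vs {name_b} RMSE {rmse:.5f} is over tolerance "
            f"{rmse_tol * rmse_range:.5f} ({rmse_tol:.5f} * {rmse_range:.5f})"
        )
```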
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1638,6 +1694,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
 @pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
 def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
+    """Test MultiHeadAttention module in FP8"""
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
@@ -1691,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
     if flash_attn_supported:
-        _error(
+        compare_and_assert(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
             "flash_attn_fwd_fp8",
@@ -1699,8 +1756,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
             atol,
             rtol,
             rmse_tol,
+            True,
         )
-    _error(
+    compare_and_assert(
         fused_attn_fwd_fp8,
         fused_attn_fwd_f16,
         "fused_attn_fwd_fp8",
@@ -1708,12 +1766,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
         atol,
         rtol,
         rmse_tol,
+        True,
     )
     if is_training:
         for i in range(len(param_names[:1])):
             logging.debug("========== {:^25s} ==========".format(param_names[i]))
-            _error(
+            compare_and_assert(
                 fused_attn_bwd_fp8[i],
                 fused_attn_bwd_f16[i],
                 f"fused_attn_bwd_fp8[{i}]",
@@ -1721,10 +1780,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
                 atol,
                 rtol,
                 rmse_tol,
+                True,
             )

 def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
+    """Run MultiHeadAttention module in FP8"""
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1851,6 +1912,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
 def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
+    """Test DotProductAttention module in FP8"""
     config = model_configs_fp8_vs_f16[model]
     # TODO(cyang): think of another way to verify dropout results
...@@ -1920,7 +1982,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1920,7 +1982,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
bwd_names = ["dq", "dk", "dv"] bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported: if flash_attn_supported:
_error( compare_and_assert(
flash_attn_fwd_fp8, flash_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
"flash_attn_fwd_fp8", "flash_attn_fwd_fp8",
...@@ -1928,6 +1990,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1928,6 +1990,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol, atol,
rtol, rtol,
rmse_tol, rmse_tol,
True,
) )
if config.dropout_p != 0.0: if config.dropout_p != 0.0:
# test cuDNN FP8 dropout # test cuDNN FP8 dropout
...@@ -1935,7 +1998,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1935,7 +1998,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
fused_attn_fwd_fp8 == 1 fused_attn_fwd_fp8 == 1
), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
else: else:
_error( compare_and_assert(
fused_attn_fwd_fp8, fused_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
"fused_attn_fwd_fp8", "fused_attn_fwd_fp8",
...@@ -1943,11 +2006,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1943,11 +2006,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol, atol,
rtol, rtol,
rmse_tol, rmse_tol,
True,
) )
if is_training: if is_training:
for i, _ in enumerate(fused_attn_bwd_f16): for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i])) logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error( compare_and_assert(
fused_attn_bwd_fp8[i], fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i], fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]", f"fused_attn_bwd_fp8[{i}]",
...@@ -1955,11 +2019,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1955,11 +2019,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol, atol,
rtol, rtol,
rmse_tol, rmse_tol,
True,
) )
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
"""Run DotProductAttention module in FP8"""
reset_rng_states() reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -2092,7 +2157,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
 model_configs_fp8 = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "fp8_1": ModelConfig(1, 512, 1, 64),
     "fp8_2": ModelConfig(4, 512, 16, 64),
     "fp8_3": ModelConfig(1, 2048, 1, 128),
@@ -2147,7 +2212,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     atol = 5e-1
     rtol = 5e-1
     rmse_tol = 0.13
-    _error(
+    compare_and_assert(
         fused_attn_fwd_fp8,
         unfused_attn_fwd_f16,
         "fused_attn_fwd_fp8",
@@ -2155,8 +2220,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
         atol,
         rtol,
         rmse_tol,
+        True,
     )
-    _error(
+    compare_and_assert(
         fused_attn_bwd_fp8,
         unfused_attn_bwd_f16,
         "fused_attn_bwd_fp8",
@@ -2164,6 +2230,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
         atol,
         rtol,
         rmse_tol,
+        True,
     )
...
@@ -6,6 +6,7 @@ import os
 import subprocess
 import sys
 import pathlib
+import logging

 import pytest
 import torch
@@ -19,13 +20,15 @@ _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
 from utils import ModelConfig, get_available_attention_backends

+pytest_logging_level = logging.getLevelName(logging.root.level)

 # Initialize RNG state
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)

 model_configs_flash_attn = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
     "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
     "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
@@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
         pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

     config = model_configs_flash_attn[model]
+    config.context_parallel = True
+    config.cp_comm_type = cp_comm_type
     if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and qkv_format == "thd":
@@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
         )
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
+    available_backends, *_ = get_available_attention_backends(
+        config,
+        qkv_dtype=dtypes[dtype],
+        qkv_layout="_".join([qkv_format] * 3),
+    )
+    flash_attn_supported, *_ = available_backends
+    if not flash_attn_supported:
+        pytest.skip("No attention backend available.")

     subprocess.run(
         get_bash_arguments(
@@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
             qkv_format=qkv_format,
             kernel_backend="FlashAttention",
             cp_comm_type=cp_comm_type,
+            log_level=pytest_logging_level,
         ),
         check=True,
     )

 model_configs_fused_attn = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
+    # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
     "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
     "cp_1_2": ModelConfig(
@@ -135,6 +150,15 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
     ),  # MLA
     "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
+    "cp_4_0": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
+    ),  # GQA
+    "cp_4_1": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
+    ),  # GQA
+    "cp_4_2": ModelConfig(
+        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
+    ),  # GQA
 }
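The `cp_4_2` config exercises `softmax_type="learnable"`, where the per-head sink logit is a trainable parameter and must receive a gradient (the `d_softmax_offset` returned by the single-GPU harness). A standalone autograd sketch of that property, with illustrative names rather than TE's actual module:

```python
import torch


def sink_attention(q, k, v, offset):
    """Scaled dot-product attention with a learnable per-head sink logit.

    q, k, v: [batch, heads, seq, dim]; offset: [heads], requires_grad."""
    logits = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    # Broadcast the per-head sink logit to one extra column per row.
    sink = offset.view(1, -1, 1, 1).expand(*logits.shape[:-1], 1)
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    return probs[..., :-1] @ v  # drop the sink column before weighting V


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
offset = torch.zeros(4, requires_grad=True)
out = sink_attention(q, k, v, offset)
out.sum().backward()  # offset.grad now holds the sink gradient
```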
...@@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention is only supported on sm90+!") pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model] config = model_configs_fused_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!") pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather": if qkv_format == "thd" and cp_comm_type == "all_gather":
...@@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!") pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!") pytest.skip("MLA CP currently does not support FP8 attention!")
if dtype == "fp8" and config.softmax_type != "vanilla":
pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtypes[dtype], qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3), qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
    )
    _, fused_attn_supported, _ = available_backends
    if not fused_attn_supported:
...@@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
            kernel_backend="FusedAttention",
            cp_comm_type=cp_comm_type,
            fp8_mha=fp8_mha,
log_level=pytest_logging_level,
        ),
        check=True,
    )
...@@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        pad_between_seqs=False,
        is_training=False,
        fp8=is_fp8,
...
...@@ -20,6 +20,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    get_attention_backend,
    AttentionParams,
    AttentionLogging,
    check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
...@@ -137,6 +138,31 @@ def reset_rng_states() -> None:
    torch.cuda.set_rng_state(cuda_rng_state)
def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
    """Assert that tensors a and b match within atol/rtol; for FP8, fall back to
    an RMSE check relative to the combined dynamic range of the two tensors."""
    if not is_fp8:
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
        return
    try:
        if a.dtype != b.dtype:
            a = a.to(b.dtype)
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
    except Exception as e:
        logging.debug(e)
        rmse = torch.sqrt((a - b).square().mean()).item()
        logging.debug(f"{name_a} vs {name_b} RMSE: {rmse:.6f}")
        rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
        assert rmse < rmse_tol * rmse_range, (
            f"{name_a} vs {name_b} RMSE {rmse:.5f} is over tolerance "
            f"{rmse_tol * rmse_range:.5f} ({rmse_tol:.5f} * {rmse_range:.5f})"
        )
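The FP8 branch of `compare_and_assert` above falls back to a relative-RMSE criterion when elementwise `assert_close` fails: the RMSE of the difference must stay below `rmse_tol` times the combined dynamic range of the two inputs. A standalone pure-Python sketch of that criterion (operating on plain lists rather than tensors, for illustration):

```python
import math

def rmse_within_range_tol(a, b, rmse_tol):
    # Root-mean-square error of the elementwise difference.
    rmse = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))
    # Combined dynamic range of both inputs, used to scale the tolerance.
    rmse_range = max(max(a), max(b)) - min(min(a), min(b))
    return rmse < rmse_tol * rmse_range
```

Scaling by the range rather than using an absolute bound keeps a single `rmse_tol` meaningful across tensors of very different magnitudes, which is what FP8 outputs tend to produce.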
class ModelConfig:
    def __init__(
        self,
...@@ -147,12 +173,15 @@ class ModelConfig:
        max_seqlen_kv: int = None,
        num_gqa_groups: int = None,
        head_dim_v: int = None,
        softmax_type: str = "vanilla",
        dropout_p: float = 0.0,
        attn_mask_type: str = "no_mask",
        attn_bias_type: str = "no_bias",
        alibi_type: str = "none",
        bias_shape: str = "1hss",
        window_size: Tuple[int, int] = (-1, -1),
        context_parallel: bool = False,
        cp_comm_type: str = "p2p",
        total_requests: int = None,
        max_ctx_len: int = None,
        num_layers: int = 1,
...@@ -171,13 +200,16 @@ class ModelConfig:
        self.kv_channels = (self.head_dim_qk, self.head_dim_v)
        self.hidden_size = self.num_heads * self.head_dim_qk
        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
        self.softmax_type = softmax_type
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        self.alibi_type = alibi_type
        self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
        self.bias_shape = bias_shape
        self.window_size = check_set_window_size(self.attn_mask_type, window_size)
        self.context_parallel = context_parallel
        self.cp_comm_type = cp_comm_type
        self.total_requests = total_requests
        self.max_ctx_len = max_ctx_len
        self.num_layers = num_layers
...@@ -198,9 +230,7 @@ def get_available_attention_backends(
    config: ModelConfig,
    qkv_dtype: torch.dtype,
    qkv_layout: str,
    pad_between_seqs: bool = False,
    deterministic: bool = False,
    fp8: bool = False,
    fp8_meta: Optional[Dict[str, Any]] = None,
...@@ -250,19 +280,21 @@ def get_available_attention_backends(
        head_dim_qk=config.head_dim_qk,
        head_dim_v=config.head_dim_v,
        attn_mask_type=config.attn_mask_type,
        window_size=config.window_size,
        alibi_slopes_shape=alibi_slopes_shape,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias_shape=core_attention_bias_shape,
        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
        pad_between_seqs=pad_between_seqs,
        attention_dropout=config.dropout_p,
        context_parallel=config.context_parallel,
        cp_comm_type=config.cp_comm_type,
        deterministic=deterministic,
        fp8=fp8,
        fp8_meta=fp8_meta,
        is_training=is_training,
        inference_params=inference_params,
        softmax_type=config.softmax_type,
    )
    (
        use_flash_attention,
...
...@@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right) {
  using namespace transformer_engine;
  NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
  const int device_id = cuda::current_device();
...@@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
      // TODO (cyang): add is_training to nvte_get_fused_attn_backend
      // sm90: fwd d<=256, bwd d=128 only
      // sm100: fwd d<=128, bwd d<=128
      ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
       (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
       (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
      head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
      (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
...@@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
       attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
      (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
      !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
      // 9.10.0: known bugs with SDPA FP8
      (cudnn_runtime_version != 91000)) {
    if (cudnn_runtime_version >= 8900) {
...@@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
        (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
       ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
       !requires_64bit_ragged_offset &&
       (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
      flag_m512 = true;
    }
    if (
...@@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        // check 64-bit ragged offset support
        (supported_ragged_offset_size) &&
        // 9.10.0/9.10.1: known bugs with SDPA F16
        (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) &&
        // softmax type
        // pre-9.13.1: vanilla
        // 9.13.1+: vanilla, off-by-one, learnable
        (cudnn_runtime_version >= 91301 ||
         (cudnn_runtime_version < 91301 &&
          softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
      flag_arb = true;
    }
    if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
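The backend-selection changes above gate the new softmax types on cuDNN version: the max-512 and FP8 paths still require vanilla softmax, while the arbitrary-seqlen path accepts off-by-one and learnable sinks from 9.13.1 on. A condensed sketch of that version gate (hypothetical helper name; versions encoded as in `cudnn_runtime_version`, e.g. 91301 for 9.13.1):

```python
def arbitrary_seqlen_softmax_ok(cudnn_runtime_version, softmax_type):
    # pre-9.13.1: vanilla only; 9.13.1+: vanilla, off-by-one, learnable.
    # Logically equivalent to the C++ condition above: the `< 91301` clause
    # there is redundant once the `>= 91301` branch has already failed.
    return cudnn_runtime_version >= 91301 or softmax_type == "vanilla"
```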
...@@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
// NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
                                   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                   const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
                                   size_t max_seqlen, bool is_training, float attn_scale,
                                   float dropout, NVTE_QKV_Layout qkv_layout,
                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                   int64_t window_size_right, NVTETensor workspace,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
  using namespace transformer_engine;
...@@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
  const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
  Tensor *input_output_S = convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensor(workspace);
...@@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...@@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
    fused_attn_arbitrary_seqlen_fwd_qkvpacked(
        b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
        attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
        input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
        input_rng_state, wkspace, stream, handle);
#else
    NVTE_ERROR(
        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
// NVTE fused attention BWD with packed QKV
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                   size_t max_seqlen, float attn_scale, float dropout,
                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                                   int64_t window_size_left, int64_t window_size_right,
                                   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
...@@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  Tensor *input_output_dP = convertNVTETensorCheck(dP);
  Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
  Tensor *output_dBias = convertNVTETensorCheck(dBias);
  Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
  Tensor *wkspace = convertNVTETensor(workspace);
  auto ndim = input_QKV->data.shape.size();
...@@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...@@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
    size_t i = 0;
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    Tensor *input_Bias = nullptr, *input_SoftmaxOffset = nullptr;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    }
    fused_attn_arbitrary_seqlen_bwd_qkvpacked(
        b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
        softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
        input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
        output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
        stream, handle);
#else
    const char *err_msg =
        "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
}
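The backward paths now unpack `Aux_CTX_Tensors` with a running index rather than fixed positions, because the pack's length varies with the bias and softmax types. The expected ordering can be sketched as follows (a hypothetical Python helper mirroring the C++ logic; names are illustrative):

```python
def unpack_aux_ctx_tensors(aux, bias_type, softmax_type):
    # Order: softmax stats and RNG state always come first; the bias and
    # softmax-offset slots are present only when their feature is enabled.
    names = ["softmax_stats", "rng_state"]
    if bias_type not in ("no_bias", "alibi"):
        names.append("bias")
    if softmax_type != "vanilla":
        names.append("softmax_offset")
    return dict(zip(names, aux))
```

Index-based unpacking keeps the forward and backward passes agreeing on slot positions even as optional tensors are added or dropped.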
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
    NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
    const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
    const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
    size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...@@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked(
  const Tensor *input_Q = convertNVTETensorCheck(Q);
  const Tensor *input_KV = convertNVTETensorCheck(KV);
  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
  const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
  Tensor *input_output_S = convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensor(workspace);
...@@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...@@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked(
    fused_attn_arbitrary_seqlen_fwd_kvpacked(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
        dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
        window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
        Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
        input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
        wkspace, stream, handle);
#else
    NVTE_ERROR(
        "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
    NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
    const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...@@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked(
  Tensor *output_dQ = convertNVTETensorCheck(dQ);
  Tensor *output_dKV = convertNVTETensorCheck(dKV);
  Tensor *output_dBias = convertNVTETensorCheck(dBias);
  Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
  Tensor *wkspace = convertNVTETensor(workspace);
  size_t b = input_cu_seqlens_q->data.shape[0] - 1;
...@@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked(
#endif #endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903) #if (CUDNN_VERSION >= 8903)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); size_t i = 0;
Tensor *input_Bias, *input_rng_state; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); }
} else { if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
} }
fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked(
 }
 // NVTE fused attention FWD with separate Q, K and V
 void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
-                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
-                         NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
-                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                          const NVTETensor page_table_v, const NVTETensor rng_state,
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-                         cudaStream_t stream) {
+                         NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+                         int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const Tensor *input_K = convertNVTETensorCheck(K);
   const Tensor *input_V = convertNVTETensorCheck(V);
   const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
   Tensor *input_output_S = convertNVTETensorCheck(S);
   Tensor *output_O = convertNVTETensorCheck(O);
   Tensor *wkspace = convertNVTETensor(workspace);
@@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+      h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     fused_attn_arbitrary_seqlen_fwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
-        dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
-        input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
-        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+        dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
+        window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
+        Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+        wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
                          const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
-                         NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
-                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                         NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
                          size_t max_seqlen_kv, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                         NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
-                         int64_t window_size_right, bool deterministic, NVTETensor workspace,
-                         cudaStream_t stream) {
+                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                         int64_t window_size_left, int64_t window_size_right, bool deterministic,
+                         NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   Tensor *output_dK = convertNVTETensorCheck(dK);
   Tensor *output_dV = convertNVTETensorCheck(dV);
   Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
   Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_Q->data.shape.size();
@@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    Tensor *input_Bias, *input_rng_state;
+    size_t i = 0;
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    Tensor *input_Bias, *input_SoftmaxOffset;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    } else {
-      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    }
+    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+      input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     }
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
-        qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
-        input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
-        output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
-        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
+        qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+        deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
+        input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
+        output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
     const char *err_msg =
...
@@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
     int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
     float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
-    void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
-    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
+    void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats,
+    void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
+    void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
     void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
     void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
@@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     is_causal = true;
     is_bottom_right = false;
   }
+  bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
   bool is_dropout = (is_training && dropout_probability != 0.0f);
   NVTE_QKV_Format q_format = nvte_get_q_format(layout);
   NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
@@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     s_q = is_ragged_q ? max_t_q : s_q;
     s_kv = is_ragged_kv ? max_t_kv : s_kv;
   }
   const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
   try {
     FADescriptor_v1 descriptor{b,
                                h,
@@ -122,6 +124,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                                layout,
                                bias_type,
                                mask_type,
+                               softmax_type,
                                window_size_left,
                                window_size_right,
                                true,
@@ -138,6 +141,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // O
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // Stats
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // bias
+                       std::shared_ptr<fe::graph::Tensor_attributes>,  // softmax_offset
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_q
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_kv
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // page_table_k
@@ -168,7 +172,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
             .set_intermediate_data_type(fe::DataType_t::FLOAT)
             .set_compute_data_type(fe::DataType_t::FLOAT);
-    std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
+    std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale, softmax_offset;
     std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
     std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
     std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
@@ -302,6 +306,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
     }
+    if (is_softmax_offset) {
+      softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                             .set_name("softmax_offset")
+                                             .set_dim({1, h, 1, 1})
+                                             .set_stride({h, 1, 1, 1})
+                                             .set_data_type(fe::DataType_t::FLOAT));
+      sdpa_options.set_sink_token(softmax_offset);
+    }
     auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
     std::vector<int64_t> o_stride(4);
@@ -338,6 +351,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
     auto Stats_tuple = std::make_tuple(Stats);
     auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
+    auto softmax_offset_tuple =
+        is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
     auto padding_tuple =
         is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
     auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
@@ -358,17 +373,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
     NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
-    auto return_tuple = std::tuple_cat(
-        std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
-        page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
+    auto return_tuple =
+        std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
+                       softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple,
+                       offset_kv_tuple, offset_s_tuple, dropout_tuple);
     cache.insert({descriptor, return_tuple});
     return return_tuple;
   };
-  auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
-        offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
-      get_graph(sdpa_f16_fprop_cache, descriptor);
+  auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv,
+        page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
+        dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
   // Exit to request upper level API to allocate memory if needed
   // n.b. Care should be taken to align each of the added workspace tensors to their type.
@@ -473,6 +489,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       variant_pack[dropout_seed] = devPtrDropoutSeed;
       variant_pack[dropout_offset] = devPtrDropoutOffset;
     }
+    if (is_softmax_offset) {
+      variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+    }
     NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
   } catch (cudnn_frontend::cudnnException &e) {
     NVTE_ERROR(e.what());
@@ -483,14 +504,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
     float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
-    void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
-    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
-    void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-    void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
-    cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
-    cudaStream_t stream, cudnnHandle_t handle) {
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
+    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
+    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
+    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
+    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
+    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -506,6 +527,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     is_causal = true;
     is_bottom_right = false;
   }
+  bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
   bool is_dropout = (dropout_probability != 0.0f);
   NVTE_QKV_Format q_format = nvte_get_q_format(layout);
   NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
@@ -558,6 +580,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                layout,
                                bias_type,
                                mask_type,
+                               softmax_type,
                                window_size_left,
                                window_size_right,
                                deterministic,
@@ -579,6 +602,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // dV
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // bias
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // dBias
+                       std::shared_ptr<fe::graph::Tensor_attributes>,  // softmax_offset
+                       std::shared_ptr<fe::graph::Tensor_attributes>,  // d_softmax_offset
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_q
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_kv
                        std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_q
@@ -608,7 +633,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
             .set_compute_data_type(fe::DataType_t::FLOAT);
     std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
-    std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
+    std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, softmax_offset, d_softmax_offset,
+        seq_q, seq_kv;
     std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
         offset_stats;
     std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
@@ -771,6 +797,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
     }
+    if (is_softmax_offset) {
+      softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                             .set_name("softmax_offset")
+                                             .set_dim({1, h, 1, 1})
+                                             .set_stride({h, 1, 1, 1})
+                                             .set_data_type(fe::DataType_t::FLOAT));
+      sdpa_backward_options.set_sink_token(softmax_offset);
+      d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                               .set_name("d_softmax_offset")
+                                               .set_dim({1, h, 1, 1})
+                                               .set_stride({h, 1, 1, 1})
+                                               .set_data_type(fe::DataType_t::FLOAT));
+      sdpa_backward_options.set_dsink_token(d_softmax_offset);
+    }
     auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
     dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
@@ -796,6 +837,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                std::shared_ptr<fe::graph::Tensor_attributes>>  // dV
         key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
     auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
+    auto softmax_offset_tuple = is_softmax_offset
+                                    ? std::make_tuple(softmax_offset, d_softmax_offset)
+                                    : std::make_tuple(nullptr, nullptr);
     auto padding_tuple =
         is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
     auto offset_qo_tuple =
@@ -814,17 +858,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
     NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
-    auto return_tuple =
-        std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
-                       offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
+    auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
+                                       softmax_offset_tuple, padding_tuple, offset_qo_tuple,
+                                       offset_kv_tuple, offset_s_tuple, dropout_tuple);
     cache.insert({descriptor, return_tuple});
     return return_tuple;
   };
-  auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
-        offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
-      get_graph(sdpa_f16_bprop_cache, descriptor);
+  auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
+        d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
+        dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
   // Exit to request upper level API to allocate memory if needed
   // n.b. Care should be taken to align each of the added workspace tensors to their type.
@@ -938,6 +982,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       variant_pack[dropout_offset] = devPtrDropoutOffset;
     }
+    if (is_softmax_offset) {
+      variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+      variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
+    }
     NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
   } catch (cudnn_frontend::cudnnException &e) {
     NVTE_ERROR(e.what());
...@@ -949,8 +998,9 @@ using namespace transformer_engine::fused_attn; ...@@ -949,8 +998,9 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
@@ -977,6 +1027,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    bias_b = input_Bias->data.shape[0];
    bias_h = input_Bias->data.shape[1];
  }
  void *devPtrSoftmaxOffset = nullptr;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
  }
  void *devPtrO = output_O->data.dptr;
  void *devPtrS = nullptr;
@@ -990,53 +1044,50 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    max_tokens = get_max_tokens(num_tokens);
  }

  size_t i = 0;
  if (Aux_CTX_Tensors->size == 0) {
    const auto cudnn_runtime_version = cudnnGetVersion();
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_S->data.dptr = nullptr;
    if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
    }
    output_S->data.dtype = DType::kFloat32;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = nullptr;
    output_rng_state->data.shape = {2};
    output_rng_state->data.dtype = DType::kInt64;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = nullptr;
      output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
      output_bias->data.dtype = QKV_type;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = nullptr;
      output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
      output_softmax_offset->data.dtype = DType::kFloat32;
    }
    Aux_CTX_Tensors->size = i;
  } else if (Aux_CTX_Tensors->size >= 2) {
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    devPtrS = output_S->data.dptr;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = rng_state->data.dptr;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = devPtrBias;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
    }
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
@@ -1050,11 +1101,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
  fused_attn_arbitrary_seqlen_fwd_impl(
      batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
      max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
      devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr,
      nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
      workspace->data.dptr, &workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1074,9 +1125,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
    const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
    Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
@@ -1122,6 +1174,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
  void *devPtrSoftmaxStats = nullptr;
  devPtrSoftmaxStats = output_S->data.dptr;
  void *devPtrSoftmaxOffset = nullptr;
  void *devPtrdSoftmaxOffset = nullptr;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
    devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
  }

  void *devPtrCuSeqlens = cu_seqlens->data.dptr;
  void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
@@ -1135,11 +1193,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
  fused_attn_arbitrary_seqlen_bwd_impl(
      batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
      max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
      bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
      devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
      devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
      devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1161,12 +1219,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
@@ -1192,6 +1250,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    bias_b = input_Bias->data.shape[0];
    bias_h = input_Bias->data.shape[1];
  }
  void *devPtrSoftmaxOffset = nullptr;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
  }

  void *devPtrO = output_O->data.dptr;
  void *devPtrS = nullptr;
@@ -1216,53 +1278,50 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
  size_t i = 0;
  if (Aux_CTX_Tensors->size == 0) {
    const auto cudnn_runtime_version = cudnnGetVersion();
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
    }
    output_S->data.dtype = DType::kFloat32;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = nullptr;
    output_rng_state->data.shape = {2};
    output_rng_state->data.dtype = DType::kInt64;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = nullptr;
      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
      output_bias->data.dtype = QKV_type;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = nullptr;
      output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
      output_softmax_offset->data.dtype = DType::kFloat32;
    }
    Aux_CTX_Tensors->size = i;
  } else if (Aux_CTX_Tensors->size >= 2) {
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    devPtrS = output_S->data.dptr;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = rng_state->data.dptr;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = devPtrBias;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
    }
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
@@ -1277,11 +1336,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
      max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
      devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1302,10 +1361,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
    cudaStream_t stream, cudnnHandle_t handle) {
@@ -1359,6 +1419,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
  void *devPtrSoftmaxStats = nullptr;
  devPtrSoftmaxStats = output_S->data.dptr;
  void *devPtrSoftmaxOffset = nullptr;
  void *devPtrdSoftmaxOffset = nullptr;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
    devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
  }
  void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
  void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1374,9 +1440,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
  fused_attn_arbitrary_seqlen_bwd_impl(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
      devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
      devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
      workspace->data.dptr, &workspace_size, stream, handle);
@@ -1401,12 +1468,13 @@ void fused_attn_arbitrary_seqlen_fwd(
    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
@@ -1425,6 +1493,10 @@ void fused_attn_arbitrary_seqlen_fwd(
    bias_b = input_Bias->data.shape[0];
    bias_h = input_Bias->data.shape[1];
  }
  void *devPtrSoftmaxOffset = nullptr;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
  }

  void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
  void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1446,53 +1518,50 @@ void fused_attn_arbitrary_seqlen_fwd(
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
  size_t i = 0;
  if (Aux_CTX_Tensors->size == 0) {
    const auto cudnn_runtime_version = cudnnGetVersion();
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
    }
    output_S->data.dtype = DType::kFloat32;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = nullptr;
    output_rng_state->data.shape = {2};
    output_rng_state->data.dtype = DType::kInt64;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = nullptr;
      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
      output_bias->data.dtype = QKV_type;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = nullptr;
      output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
      output_softmax_offset->data.dtype = DType::kFloat32;
    }
    Aux_CTX_Tensors->size = i;
  } else if (Aux_CTX_Tensors->size >= 2) {
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    devPtrS = output_S->data.dptr;
    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = rng_state->data.dptr;
    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_bias->data.dptr = devPtrBias;
    }
    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
      output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
    }
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
@@ -1507,11 +1576,11 @@ void fused_attn_arbitrary_seqlen_fwd(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
      max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
      devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1532,13 +1601,14 @@ void fused_attn_arbitrary_seqlen_bwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
  void *devPtrQ = input_Q->data.dptr;
@@ -1577,6 +1647,12 @@ void fused_attn_arbitrary_seqlen_bwd(
  void *devPtrdV = output_dV->data.dptr;
  void *devPtrSoftmaxStats = nullptr;
  devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
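The backward pass only wires up the softmax-offset pointers for the non-vanilla types, since only the learnable variant has a gradient to produce. For the learnable softmax defined later in this PR, p_i = exp(s_i) / (exp(alpha) + sum_j exp(s_j)), the offset gradient follows from d p_i / d alpha = -p_i * sink, where sink = exp(alpha) / denom is the probability mass absorbed by the sink. A standalone sketch verifying that formula against finite differences (pure-Python reference, not the cuDNN kernel):

```python
import math

def learnable_softmax(s, alpha):
    # single-head reference of NVTE_LEARNABLE_SOFTMAX over a list of logits
    e = [math.exp(x) for x in s]
    denom = math.exp(alpha) + sum(e)
    return [x / denom for x in e]

s, alpha, eps = [0.5, -1.0, 2.0], 0.3, 1e-6

# analytic gradient: d p_i / d alpha = -p_i * sink, sink = exp(alpha) / denom
denom = math.exp(alpha) + sum(math.exp(x) for x in s)
sink = math.exp(alpha) / denom
analytic = [-p * sink for p in learnable_softmax(s, alpha)]

# central finite differences as a check
plus = learnable_softmax(s, alpha + eps)
minus = learnable_softmax(s, alpha - eps)
numeric = [(a - b) / (2 * eps) for a, b in zip(plus, minus)]

assert all(abs(a - n) < 1e-8 for a, n in zip(analytic, numeric))
```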
  void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
  void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1592,9 +1668,10 @@ void fused_attn_arbitrary_seqlen_bwd(
  fused_attn_arbitrary_seqlen_bwd_impl(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
      devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
      devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
      workspace->data.dptr, &workspace_size, stream, handle);
...
@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
    const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
    Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd_kvpacked(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
    cudaStream_t stream, cudnnHandle_t handle);
@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

#endif  // CUDNN_VERSION >= 8900
}  // namespace transformer_engine
...
@@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1(
        layout,
        bias_type,
        mask_type,
        NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
        0,
        0,
        true,
@@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1(
        layout,
        bias_type,
        mask_type,
        NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
        0,
        0,
        false,
...
@@ -107,6 +107,7 @@ struct FADescriptor_v1 {
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_Softmax_Type softmax_type;
  std::int64_t window_size_left;
  std::int64_t window_size_right;
  bool deterministic;
@@ -116,14 +117,15 @@ struct FADescriptor_v1 {
  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                    window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type,
                    bwd_tensor_type) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                    rhs.fwd_tensor_type, rhs.bwd_tensor_type);
  }
};
...
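The `std::tie` comparison above is lexicographic comparison over every descriptor field, and `FADescriptor_v1` is used as a plan-cache key; adding `softmax_type` to both tuples keeps plans built for different softmax types from colliding in the cache. Python tuples compare the same way, so the idea can be sketched with a hypothetical field subset (not the real struct):

```python
# Illustrative subset of FADescriptor_v1 fields, used as a plan-cache key.
def plan_cache_key(b, h, s_q, s_kv, mask_type, softmax_type):
    # std::tie(...) < std::tie(rhs...) is exactly lexicographic tuple comparison
    return (b, h, s_q, s_kv, mask_type, softmax_type)

vanilla = plan_cache_key(2, 16, 512, 512, "causal", "vanilla")
sink = plan_cache_key(2, 16, 512, 512, "causal", "learnable")

# Without softmax_type in the key, these two configurations would map to the
# same cached cuDNN plan even though their graphs differ.
assert vanilla != sink
assert len({vanilla, sink}) == 2
```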
@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
  NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
 * \brief Attention softmax types, as described in
 * "Efficient Streaming Language Models with Attention Sinks" (https://arxiv.org/pdf/2309.17453v3).
 * For a given attention score S = Q*K^T, each softmax type performs a different operation on S:
 * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
 * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
 * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
 * where alpha is a learnable parameter of shape [H].
 */
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
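The three enum values can be checked against the formulas in the doc comment with a small pure-Python reference (single head, no numerical stabilization; a sketch, not the cuDNN kernel):

```python
import math

def nvte_softmax(scores, softmax_type="vanilla", alpha=0.0):
    """Reference of the three NVTE softmax types over a list of logits."""
    e = [math.exp(s) for s in scores]
    denom = sum(e)
    if softmax_type == "off_by_one":
        denom += 1.0                    # implicit sink logit of 0: exp(0) = 1
    elif softmax_type == "learnable":
        denom += math.exp(alpha)        # learnable per-head sink logit alpha
    return [x / denom for x in e]

s = [1.0, 2.0, 3.0]
assert abs(sum(nvte_softmax(s)) - 1.0) < 1e-12           # vanilla sums to one
assert sum(nvte_softmax(s, "off_by_one")) < 1.0          # some mass leaks to the sink
assert nvte_softmax(s, "off_by_one") == nvte_softmax(s, "learnable", alpha=0.0)
```

The last assertion shows why off-by-one softmax is the special case of the learnable variant with alpha fixed at 0, which is also why only the learnable type needs a SoftmaxOffset gradient.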
/*! \enum NVTE_Fused_Attn_Backend
 * \brief Fused attention backends
 */
@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 * \param[in] qkv_layout The layout of Tensors Q, K, V.
 * \param[in] bias_type The attention bias type.
 * \param[in] attn_mask_type The attention mask type.
 * \param[in] softmax_type The attention softmax type.
 * \param[in] dropout The dropout probability.
 * \param[in] num_attn_heads The number of heads in Q.
 * \param[in] num_gqa_groups The number of heads in K, V.
@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
 *
@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 *
 * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
 * \param[in] Bias The Bias tensor.
 * \param[in] SoftmaxOffset The SoftmaxOffset tensor.
 * \param[in,out] S The S tensor.
 * \param[out] O The output O tensor.
 * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 * \param[in] qkv_layout QKV tensor's layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_qkvpacked(
    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
    bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
 *
@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(
 * e.g. M, ZInv, rng_state.
 * \param[out] dQKV The gradient of the QKV tensor.
 * \param[out] dBias The gradient of the Bias tensor.
 * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
 * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
 * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
 * \param[in] max_seqlen Max sequence length used for computing,
@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(
 * \param[in] qkv_layout QKV tensor's layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                   size_t max_seqlen, float attn_scale, float dropout,
                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                                   int64_t window_size_left, int64_t window_size_right,
                                   bool deterministic, NVTETensor workspace, cudaStream_t stream);
@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 * \param[in] Q The Q tensor, in HD layouts.
 * \param[in] KV The KV tensor, in 2HD or H2D layouts.
 * \param[in] Bias The Bias tensor.
 * \param[in] SoftmaxOffset The SoftmaxOffset tensor.
 * \param[in,out] S The S tensor.
 * \param[out] O The output O tensor.
 * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 * \param[in] qkv_layout QKV tensor's layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
    NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
    const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
    const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
    size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
 *
@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
 * \param[out] dQ The gradient of the Q tensor.
 * \param[out] dKV The gradient of the KV tensor.
 * \param[out] dBias The gradient of the Bias tensor.
 * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
 * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
 * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
 * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
 * \param[in] qkv_layout QKV tensor's layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
    NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
    const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
 *
@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
 * \param[in] K The K tensor.
 * \param[in] V The V tensor.
 * \param[in] Bias The Bias tensor.
 * \param[in] SoftmaxOffset The SoftmaxOffset tensor.
 * \param[in,out] S The S tensor.
 * \param[out] O The output O tensor.
 * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
 * \param[in] qkv_layout QKV tensors' layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
                         const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                         float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                         NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                         NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                         int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
 *
@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 * \param[out] dK The gradient of the K tensor.
 * \param[out] dV The gradient of the V tensor.
 * \param[out] dBias The gradient of the Bias tensor.
 * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
 * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
 * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
 * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 * \param[in] qkv_layout QKV tensors' layout.
 * \param[in] bias_type Bias type.
 * \param[in] attn_mask_type Attention mask type.
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
                         const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
                         NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
                         const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
                         size_t max_seqlen_kv, float attn_scale, float dropout,
                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                         int64_t window_size_left, int64_t window_size_right, bool deterministic,
                         NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
 *
......
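The backward entry points above now return a `dSoftmaxOffset` gradient alongside `dQ`/`dK`/`dV`/`dBias`. As a minimal, illustrative sketch of where that gradient comes from (pure Python, not the cuDNN kernel; `learnable_softmax` and `d_offset` are hypothetical names for this example), the learnable-sink softmax adds one extra logit equal to the offset, and the offset gradient follows from `dp_i/d(offset) = -p_i * p_sink`, checked here against finite differences:

```python
import math

def learnable_softmax(logits, offset):
    # Softmax with one extra "sink" logit equal to `offset`
    # (offset = 0.0 reduces to the off-by-one variant); the sink
    # probability itself is dropped from the returned row.
    m = max(max(logits), offset)
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(offset - m)
    return [e / denom for e in exps]

def d_offset(logits, offset, upstream):
    # Analytic gradient of loss = sum(upstream * p) w.r.t. offset:
    # dp_i/d(offset) = -p_i * p_sink, so d(loss) = -p_sink * loss.
    p = learnable_softmax(logits, offset)
    p_sink = 1.0 - sum(p)  # the sink absorbs the leftover probability mass
    return -p_sink * sum(u * pi for u, pi in zip(upstream, p))

logits, offset, upstream = [0.5, -1.0, 2.0], 0.3, [1.0, 2.0, 3.0]

# finite-difference check of the analytic offset gradient
eps = 1e-6
f = lambda o: sum(u * p for u, p in zip(upstream, learnable_softmax(logits, o)))
fd = (f(offset + eps) - f(offset - eps)) / (2 * eps)
assert abs(fd - d_offset(logits, offset, upstream)) < 1e-6
```

In the fused path this reduction runs inside cuDNN; the sketch only shows why the offset receives a single scalar-per-head gradient rather than a per-position one.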
@@ -36,6 +36,10 @@
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
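The new `NVTE_Softmax_Type` enum distinguishes vanilla softmax, off-by-one softmax (a fixed zero sink logit), and learnable softmax (a trainable sink offset, carried in the `SoftmaxOffset` tensor). A small pure-Python sketch, purely illustrative of the math rather than the kernel, shows the key identity: off-by-one softmax is vanilla softmax over the logits augmented with a zero sink logit, with the sink slot dropped:

```python
import math

def vanilla_softmax(logits):
    # NVTE_VANILLA_SOFTMAX: standard softmax, each row sums to 1
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def off_by_one_softmax(logits):
    # NVTE_OFF_BY_ONE_SOFTMAX: the denominator gains a "+1" term
    # (an implicit zero sink logit), so a row may sum to less than 1
    # and a query can effectively attend to nothing
    m = max(max(logits), 0.0)
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(0.0 - m)
    return [e / denom for e in exps]

# Equivalence: off-by-one softmax == vanilla softmax over logits + [0.0],
# dropping the sink slot afterwards.
logits = [1.0, 2.0, 0.5]
obo = off_by_one_softmax(logits)
aug = vanilla_softmax(logits + [0.0])[:-1]
assert all(abs(a - b) < 1e-12 for a, b in zip(obo, aug))
```

`NVTE_LEARNABLE_SOFTMAX` replaces the fixed zero sink logit with a per-head learnable offset, which is why the forward signatures gain a `SoftmaxOffset` input and the backward signatures a `dSoftmaxOffset` output.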
@@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t qk_head_dim, size_t v_head_dim,
                                            int64_t window_size_left, int64_t window_size_right) {
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
  return backend;
}
@@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
@@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
      nvte_fused_attn_fwd_qkvpacked(
          qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
          s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
          ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
          window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_fwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
          s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
          kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
          dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
          dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
          window_size_right, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_fwd(
          q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
          dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
          ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
          dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
          kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
          mask_type, softmax_type, window_size_left, window_size_right,
          query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported QKVLayout.");
    }
@@ -262,10 +268,15 @@ static void FusedAttnForwardImpl(
  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
  /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -280,12 +291,12 @@ static void FusedAttnForwardImpl(
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
    nvte_fused_attn_fwd_qkvpacked(
        qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
        window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
    auto kv_shape =
@@ -293,12 +304,13 @@ static void FusedAttnForwardImpl(
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
        s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, softmax_type, window_size_left, window_size_right,
        workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
@@ -307,12 +319,13 @@ static void FusedAttnForwardImpl(
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
        k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
        window_size_right, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }
@@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
    min_num_segments = input_batch * max_segments_per_seq;
  }
} }
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
@@ -453,37 +469,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      nvte_fused_attn_bwd_qkvpacked(
          qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
          dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
          dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
          qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
          deterministic, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_bwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
          dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
          window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                          doutput_tensor.data(),
                          s_tensor.data(),  // not used for F16
                          s_tensor.data(),  // not used for F16
                          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                          dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
                          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                          dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                          q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
                          qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
                          window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported qkv_layout.");
    }
@@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl(
  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias);
@@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl(
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
                                  q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  softmax_type, window_size_left, window_size_right, deterministic,
                                  workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
    auto kv_shape =
@@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl(
        s_tensor.data(),  // not used for F16
        s_tensor.data(),  // not used for F16
        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
        q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
        mask_type, softmax_type, window_size_left, window_size_right, deterministic,
        workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...@@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl( ...@@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
window_size_right, deterministic, workspace_tensor.data(), stream); mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
......
@@ -13,6 +13,7 @@ import logging
from packaging.version import Version as PkgVersion

import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
    SplitAlongDim,
@@ -142,6 +143,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
@@ -149,6 +151,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number
        self.softmax_type = softmax_type

        def mask_func(x, y):
            return (
@@ -185,6 +188,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        softmax_offset: torch.Tensor = None,
    ) -> torch.Tensor:
        """Unfused attention fprop"""
        assert (
@@ -326,7 +330,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
                dtype=query_layer.dtype
            )
        # add attention sink to the last column: [b, np, sq, sk+1]
        if self.softmax_type != "vanilla":
            matmul_result = torch.cat(
                [
                    matmul_result,
                    softmax_offset.to(dtype=matmul_result.dtype).expand(
                        matmul_result.size(0), -1, matmul_result.size(2), -1
                    ),
                ],
                dim=-1,
            )
            attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
            attn_mask_type = "arbitrary"

        # attention scores and attention mask
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
@@ -337,6 +355,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # remove attention sink: [b, np, sq, sk]
        if self.softmax_type != "vanilla":
            attention_probs = attention_probs[..., :-1]

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
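The concatenate-then-drop trick used by the unfused path above can be sketched in isolation. This is a minimal sketch with a hypothetical helper name (`sink_softmax` is not part of the PR); it shows why appending a per-head offset column before the softmax and dropping it afterwards implements the sink denominator:

```python
import torch

def sink_softmax(scores: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    """Append a per-head sink column, softmax over sk+1 columns, drop the column."""
    b, h, sq, _ = scores.shape
    sink = offset.to(scores.dtype).expand(b, h, sq, 1)  # extra "sink" column
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return probs[..., :-1]  # remove the sink column: [b, h, sq, sk]

scores = torch.randn(2, 4, 3, 5)
# 'off-by-one': a zero offset adds exp(0) = 1 to the softmax denominator
out = sink_softmax(scores, torch.zeros(1, 4, 1, 1))
ref = scores.exp() / (1.0 + scores.exp().sum(dim=-1, keepdim=True))
assert torch.allclose(out, ref, atol=1e-6)
```

The resulting rows sum to less than one; the missing mass is the probability absorbed by the sink.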
@@ -917,6 +939,7 @@ class FusedAttnFunc(torch.autograd.Function):
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        window_size,
        rng_gen,
        fused_attention_backend,
@@ -925,6 +948,7 @@ class FusedAttnFunc(torch.autograd.Function):
        fp8_meta,
        quantizers,
        deterministic,
        softmax_offset,
    ):
        # pylint: disable=missing-function-docstring
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
@@ -997,8 +1021,10 @@ class FusedAttnFunc(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                softmax_type,
                window_size,
                rng_gen,
                softmax_offset,
            )
            if is_output_fp8:
                out_ret = out_fp8
@@ -1059,8 +1085,10 @@ class FusedAttnFunc(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                softmax_type,
                window_size,
                rng_gen,
                softmax_offset,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None)
@@ -1114,6 +1142,7 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.softmax_type = softmax_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
@@ -1224,6 +1253,7 @@ class FusedAttnFunc(torch.autograd.Function):
                ctx.qkv_layout,
                ctx.attn_bias_type,
                ctx.attn_mask_type,
                ctx.softmax_type,
                ctx.window_size,
                ctx.deterministic,
            )
@@ -1287,42 +1317,17 @@ class FusedAttnFunc(torch.autograd.Function):
                ctx.qkv_layout,
                ctx.attn_bias_type,
                ctx.attn_mask_type,
                ctx.softmax_type,
                ctx.window_size,
                ctx.deterministic,
            )
        d_bias = None
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            d_bias = rest[0]
        d_softmax_offset = None
        if ctx.softmax_type != "vanilla":
            d_softmax_offset = rest[1]
        return (
            None,
            None,
@@ -1336,7 +1341,8 @@ class FusedAttnFunc(torch.autograd.Function):
            dq,
            dk,
            dv,
            d_bias,
            None,
            None,
            None,
            None,
@@ -1351,6 +1357,7 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
            d_softmax_offset,
        )
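The long tuples of `None` returned by `backward` follow the standard `torch.autograd.Function` contract: one entry per `forward` input, with `None` for every non-tensor argument (layouts, mask types, flags, and so on). A minimal, self-contained example of the pattern (generic, not TE code):

```python
import torch

class Scale(torch.autograd.Function):
    """backward() must return one entry per forward() input; non-tensor
    inputs get None, which is why the fused-attention backward above
    returns long tuples of None around dq/dk/dv/d_bias/d_softmax_offset."""

    @staticmethod
    def forward(ctx, x, factor, some_flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # gradients for (x, factor, some_flag): only x is a tensor input
        return grad_out * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
```

Adding `softmax_offset` as a new forward input is what forces the extra `d_softmax_offset` slot at the end of the returned tuple.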
@@ -1390,6 +1397,7 @@ class FusedAttention(torch.nn.Module):
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
@@ -1402,6 +1410,7 @@ class FusedAttention(torch.nn.Module):
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic
        self.softmax_type = softmax_type

    def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
        """
@@ -1453,6 +1462,7 @@ class FusedAttention(torch.nn.Module):
        quantizers=None,
        pad_between_seqs: bool = False,
        inference_params: Optional[InferenceParams] = None,
        softmax_offset: torch.Tensor = None,
    ) -> torch.Tensor:
        """fused attention fprop"""
        assert (
@@ -1603,6 +1613,8 @@ class FusedAttention(torch.nn.Module):
                fp8_meta=fp8_meta,
                quantizers=quantizers,
                pad_between_seqs=pad_between_seqs,
                softmax_type=self.softmax_type,
                softmax_offset=softmax_offset,
            )
        else:
            with self.attention_dropout_ctx():
@@ -1626,6 +1638,7 @@ class FusedAttention(torch.nn.Module):
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    self.softmax_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
@@ -1634,6 +1647,7 @@ class FusedAttention(torch.nn.Module):
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                    softmax_offset,
                )

        # ...hd -> ...(hd)
...
@@ -46,6 +46,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
_cu_seqlens_info_with_cp_cache = {}
_seq_chunk_ids_cache_for_reordering_before_attn = {}
_seq_chunk_ids_cache_for_reordering_after_attn = {}
_softmax_offset_chunk_ids_cache = {}


def flash_attn_p2p_communicate(
@@ -318,6 +319,55 @@ def flash_attn_a2a_communicate(
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
def flash_attn_a2a_communicate_softmax_offset(
    tensor: torch.Tensor,
    h_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Split/AllGather communication for softmax offset."""
    if tensor is None:
        return None
    global _softmax_offset_chunk_ids_cache
    device = tensor.device
    if (cp_size, device) not in _softmax_offset_chunk_ids_cache:
        chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device)
        _softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids
    else:
        chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)]
    if before_attn:
        # softmax_offset: split to CP ranks along the head dimension
        # [1, h, 1, 1] -> [1, cp, h//cp, 1, 1]
        shape = tensor.shape
        tensor = tensor.view(
            *shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :]
        )
        rank = get_distributed_rank(cp_group)
        output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank])
        output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :])
    else:
        # d_softmax_offset: all-gather from all ranks to all ranks
        # [1, h//cp, 1, 1] -> [1, h, 1, 1]
        inp = tensor.view(-1)
        output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device)
        with torch.cuda.stream(cp_stream):
            torch.distributed.all_gather_into_tensor(
                output,
                inp,
                group=cp_group,
                async_op=False,
            )
        torch.cuda.current_stream().wait_stream(cp_stream)
        output = output.view(
            *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :]
        )
    return output
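The `before_attn` branch of the function above can be exercised without a process group. This is a minimal single-process sketch (hypothetical helper name, no `torch.distributed` dependency) of the view-and-select split: each CP rank takes a contiguous chunk of heads from the `[1, h, 1, 1]` offset tensor:

```python
import torch

def split_offset_for_rank(offset: torch.Tensor, h_dim: int, cp_size: int, rank: int) -> torch.Tensor:
    """View [1, h, 1, 1] as [1, cp, h//cp, 1, 1] and take this rank's chunk of heads."""
    shape = offset.shape
    t = offset.view(*shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[h_dim + 1 :])
    out = torch.index_select(t, dim=h_dim, index=torch.tensor([rank]))
    return out.view(*shape[:h_dim], -1, *shape[h_dim + 1 :])

offset = torch.arange(8.0).view(1, 8, 1, 1)  # 8 heads, 4 CP ranks
chunks = [split_offset_for_rank(offset, 1, 4, r) for r in range(4)]
# concatenating every rank's chunk along the head dim restores the original
assert torch.equal(torch.cat(chunks, dim=1), offset)
```

The backward direction (`before_attn=False`) is the inverse: each rank contributes its `h//cp` gradient entries and `all_gather_into_tensor` reassembles the full `[1, h, 1, 1]` tensor on every rank.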
def _get_cu_seqlens_info_with_cp(
    batch_size: int,
    max_seqlen: int,
@@ -1854,7 +1904,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        )
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
                    dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2014,7 +2064,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        )
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
                    dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv // 2,
                        cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2171,7 +2221,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        )
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
                    dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
                        ctx.max_seqlen_q // 2,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2289,7 +2339,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        )
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
                    dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
@@ -3122,7 +3172,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                if ctx.use_fused_attention:
                    aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                    dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        max_seqlen_kv,
                        cu_seqlens_q,
@@ -3283,6 +3333,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        cp_stream,
        quantizers,
        use_flash_attn_3,
        softmax_type,
        softmax_offset,
    ):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
@@ -3391,6 +3443,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )
        if softmax_type != "vanilla":
            softmax_offset = flash_attn_a2a_communicate_softmax_offset(
                softmax_offset, 1, cp_size, cp_group, cp_stream, True
            )

        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            q_f16, k_f16, v_f16 = q, k, v
@@ -3430,6 +3486,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
                softmax_type=softmax_type,
                softmax_offset=softmax_offset,
            )
        if fp8:
            out = out._data
@@ -3532,6 +3590,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        ctx.use_flash_attn_3 = use_flash_attn_3
        ctx.softmax_type = softmax_type
        ctx.qkv_dtype = qkv_dtype
        ctx.dQKV_quantizer = dQKV_quantizer
@@ -3695,7 +3754,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                dout_part, fake_dtype=dout_dtype, internal=True
            )

        dq, dk, dv, *rest = fused_attn_bwd(
            ctx.max_seqlen_q,
            ctx.max_seqlen_kv,
            cu_seqlens_q,
@@ -3719,6 +3778,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
            window_size=ctx.window_size,
            deterministic=ctx.deterministic,
            **fp8_meta_kwargs,
            softmax_type=ctx.softmax_type,
        )
        if ctx.fp8:
            dq = dq._data
@@ -3763,6 +3823,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
        d_bias = None
        d_softmax_offset = None
        if ctx.use_fused_attention:
            if ctx.attn_bias_type not in ["no_bias", "alibi"]:
                d_bias = rest[0]
            if ctx.softmax_type != "vanilla":
                d_softmax_offset = rest[1]
                d_softmax_offset = flash_attn_a2a_communicate_softmax_offset(
                    d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False
                )
        if ctx.fp8:
            dq = ctx.dQKV_quantizer.create_tensor_from_data(
                dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
@@ -3793,6 +3864,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
            None,
            None,
            None,
            d_bias,
            None,
            None,
            None,
@@ -3803,6 +3875,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
            None,
            None,
            None,
            d_softmax_offset,
        )
@@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp(
    quantizers=None,
    pad_between_seqs=False,
    use_flash_attn_3=False,
    softmax_type="vanilla",
    softmax_offset=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
@@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp(
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"Context parallelism does not support {qkv_format=}!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        "Context parallelism only supports attention bias with FusedAttention backend and"
        " non-padding mask types!"
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None for context parallelism and qkv_format = 'thd'!"

    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
@@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp(
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"

    enable_mla = k.shape[-1] != v.shape[-1]
    assert not enable_mla or cp_comm_type in [
        "p2p",
        "a2a+p2p",
    ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
    if fp8 and fp8_meta is not None:
        if fp8_meta["recipe"].fp8_dpa:
            assert (
                softmax_type == "vanilla"
            ), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
    assert (
        softmax_type == "vanilla" or use_fused_attention
    ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
    assert (
        softmax_type == "vanilla" or cp_comm_type == "a2a"
    ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
    assert (
        softmax_type == "vanilla" or qkv_format != "thd"
    ), f"Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
    args = [
        is_training,
@@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp(
        args += [window_size, cp_group, cp_stream, use_flash_attn_3]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [
            window_size,
            fp8,
            fp8_meta,
            cp_group,
            cp_stream,
            quantizers,
            use_flash_attn_3,
            softmax_type,
            softmax_offset,
        ]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
...
@@ -11,6 +11,7 @@ import warnings
import logging

import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex

from transformer_engine.pytorch.utils import get_cudnn_version
@@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule):
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
                softmax type as described in this paper:
                `Efficient Streaming Language Models with Attention Sinks
                <https://arxiv.org/pdf/2309.17453v3>`_.
                For a given attention score S = Q*K^T of shape [b, h, s_q, s_kv],
                'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
                'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
                'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
                where alpha is a learnable parameter of shape [h].
                'off-by-one' and 'learnable' softmax types are also called sink attention
                ('zero sink' and 'learnable sink').
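The three formulas in the docstring can be checked numerically. A minimal sketch (plain tensor math, outside any TE module), showing that 'learnable' with alpha = 0 reduces to 'off-by-one', and that sink rows sum to less than one:

```python
import torch

S = torch.randn(1, 2, 3, 5)  # attention scores [b, h, s_q, s_kv]
alpha = torch.zeros(2)       # learnable sink parameter, one entry per head

vanilla = S.exp() / S.exp().sum(-1, keepdim=True)
off_by_one = S.exp() / (1.0 + S.exp().sum(-1, keepdim=True))
learnable = S.exp() / (alpha.exp().view(1, -1, 1, 1) + S.exp().sum(-1, keepdim=True))

assert torch.allclose(vanilla.sum(-1), torch.ones(1, 2, 3))  # rows sum to 1
assert (off_by_one.sum(-1) < 1).all()                        # sink absorbs some mass
assert torch.allclose(learnable, off_by_one)                 # exp(0) = 1
```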
    Parallelism parameters
    ----------------------
@@ -223,6 +235,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
@@ -307,6 +320,20 @@ class DotProductAttention(TransformerEngineBaseModule):
        self.attention_type = attention_type
        self.attention_dropout = attention_dropout
        self.softmax_type = softmax_type
        if self.softmax_type == "vanilla":
            self.softmax_offset = None
        if self.softmax_type == "off-by-one":
            self.softmax_offset = torch.zeros(
                self.num_attention_heads // self.tp_size, device="cuda"
            )
        if self.softmax_type == "learnable":
            self.register_parameter(
                "softmax_offset",
                Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
                get_rng_state_tracker=get_rng_state_tracker,
            )
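The distinction the initialization above draws — a fixed zero offset for 'off-by-one' versus a trainable `Parameter` for 'learnable' — can be sketched with a toy module (hypothetical class name, plain `nn.Module` rather than TE's base module):

```python
import torch
from torch.nn.parameter import Parameter

class SinkOffset(torch.nn.Module):
    """'off-by-one' keeps a fixed zero offset (no gradient);
    'learnable' registers the offset as a trainable Parameter."""

    def __init__(self, num_heads: int, softmax_type: str = "vanilla"):
        super().__init__()
        self.softmax_offset = None
        if softmax_type == "off-by-one":
            self.softmax_offset = torch.zeros(num_heads)  # plain tensor, not trained
        elif softmax_type == "learnable":
            self.softmax_offset = Parameter(torch.zeros(num_heads))  # auto-registered

m = SinkOffset(4, "learnable")
assert isinstance(m.softmax_offset, Parameter) and m.softmax_offset.requires_grad
assert not SinkOffset(4, "off-by-one").softmax_offset.requires_grad
```

This is why the PR only routes a gradient (`d_softmax_offset`) back when `softmax_type` is not "vanilla": only the 'learnable' variant has anything to train.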
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
@@ -328,6 +355,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
            softmax_type=self.softmax_type,
        )
        self.unfused_attention = UnfusedDotProductAttention(
@@ -335,6 +363,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
            softmax_type=self.softmax_type,
        )

    def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
...@@ -634,6 +663,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -634,6 +663,7 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer, query_layer,
num_gemms=3, num_gemms=3,
allow_non_contiguous=True, allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer: ) as query_layer:
# checks for RNG # checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing(): if self.rng_states_tracker is not None and is_graph_capturing():
...@@ -922,6 +952,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -922,6 +952,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
...@@ -957,11 +988,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
...@@ -1022,6 +1055,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
...@@ -1071,7 +1110,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
...@@ -1101,6 +1139,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.fused_attention(
query_layer,
...@@ -1129,6 +1168,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
...@@ -1157,6 +1197,7 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.unfused_attention(
_alibi_cache,
...@@ -1173,5 +1214,6 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return None
...@@ -24,6 +24,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
...@@ -206,6 +207,8 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
...@@ -216,6 +219,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
...@@ -237,11 +242,13 @@ class AttentionParams:
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
...@@ -308,11 +315,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
...@@ -565,6 +574,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
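As a hedged sketch (illustrative names only, not part of the source), the comment table above can be encoded as a lookup from `(context_parallel, softmax_type)` to the set of supported backends:

```python
# Illustrative encoding of the support matrix in the comment table above.
# Keys: (context_parallel, softmax_type); values: supported backends.
SINK_SUPPORT = {
    (False, "vanilla"): {"FlashAttention", "FusedAttention", "UnfusedDotProductAttention"},
    (False, "off-by-one"): {"FusedAttention", "UnfusedDotProductAttention"},
    (False, "learnable"): {"FusedAttention", "UnfusedDotProductAttention"},
    (True, "vanilla"): {"FlashAttention", "FusedAttention"},
    (True, "off-by-one"): {"FusedAttention"},
    (True, "learnable"): {"FusedAttention"},
}

# FlashAttention is filtered out whenever a non-vanilla softmax is requested.
assert "FlashAttention" not in SINK_SUPPORT[(False, "learnable")]
```

The additional FP8, thd, and `cp_comm_type` checks in the code above further restrict these sets; the table covers only the context-parallel/softmax-type axes.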
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
...@@ -806,6 +860,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
...
...@@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
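The three variants in the docstring above can be sketched as a minimal, unstabilized PyTorch reference (`sink_softmax` and `alpha` are illustrative names, not the fused kernel's API):

```python
import torch

def sink_softmax(scores, softmax_type="vanilla", alpha=None):
    # scores: attention logits S = Q*K^T of shape [b, h, s_q, s_kv].
    # alpha: learnable per-head offset of shape [h], used for 'learnable'.
    # Unstabilized for clarity; real kernels subtract the row max first.
    exp = torch.exp(scores)
    denom = exp.sum(dim=-1, keepdim=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0  # implicit extra sink logit of 0
    elif softmax_type == "learnable":
        denom = denom + torch.exp(alpha).reshape(1, -1, 1, 1)
    return exp / denom

scores = torch.randn(2, 4, 8, 8)
p = sink_softmax(scores, "off-by-one")
# rows sum to less than 1; the deficit is the sink's probability mass
assert torch.all(p.sum(dim=-1) < 1.0)
```

With `alpha = 0` per head, 'learnable' reduces to 'off-by-one', since `exp(0) = 1`.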
Parallelism parameters
----------------------
...@@ -245,6 +256,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
...@@ -262,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
...@@ -416,6 +429,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
...
...@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
...@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
...@@ -131,8 +138,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
...@@ -197,6 +206,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -205,6 +216,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
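A hedged sketch of how a caller might prepare this tensor, mirroring the reshape done in `DotProductAttention.forward` above (the parameter name and head count are illustrative):

```python
import torch

num_heads = 16  # h_q, illustrative
# Learnable per-head sink logits, stored as a [h_q] parameter...
softmax_offset_param = torch.nn.Parameter(torch.zeros(num_heads))
# ...broadcast to the [1, h_q, 1, 1] float32 layout the fused kernel expects.
softmax_offset = softmax_offset_param.reshape(1, -1, 1, 1).to(torch.float32)
assert softmax_offset.shape == (1, num_heads, 1, 1)
```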
Returns
----------
...@@ -286,6 +300,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
...@@ -300,6 +315,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
...@@ -333,6 +349,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -398,6 +415,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -417,6 +436,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
d = q.size(-1)
...@@ -454,6 +476,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
...