Unverified Commit 5e4e0b2c authored by Charlene Yang, committed by GitHub

[PyTorch] Add sink attention support from cuDNN (#2148)



* first draft; debug plan failure
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug uid error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak params
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add grad in output
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix prints in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* address review comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused grad; add softmax_type; add sink to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding mask; add swa tests; remove requires_grad for off-by-one
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix non-determinism and shapes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add GQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add CP A2A; dq/dk mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; need cleaner solution
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; pending cudnn kernel change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix world size in unit test; avoid thd format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 context
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak CP logging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow no_mask/padding for SWA(left,0)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "allow no_mask/padding for SWA(left,0)"

This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add softmax_type to Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add cuDNN version control
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prettify tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip 9.13 for MLA, non 192/128
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename compare_with_error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* small cleanups and improvements
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix minor CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force sink/dsink to be float32
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch FE to GH FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return to GH TE main FE commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.14.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up before CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* bump up cudnn version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add backend selection guard for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for softmax type enums in C
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 57b4d7bc
Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5
Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8
......@@ -17,88 +17,18 @@ from test_attention_with_cp import model_configs_flash_attn, model_configs_fused
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
def generate_input_shapes(
qkv_format: str,
config: ModelConfig,
world_size: int,
kernel_backend: str,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
......@@ -191,34 +121,158 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
assert False, f"{qkv_format=} is not supported!"
return (
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
)
def get_tols(config, dtype):
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
atol = 2.5e-2
rtol = 2.5e-2
else:
atol = 3.5e-2
rtol = 3.5e-2
rmse_tol = 0.01
elif dtype == "fp16":
atol = 5e-3
rtol = 5e-3
rmse_tol = 0.01
elif dtype == "fp8":
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
else:
assert False, f"{dtype=} is not supported!"
return atol, rtol, rmse_tol
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# set up environment variables and config
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type=} is not supported!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# set up communication group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
" = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate attention module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
# generate attention inputs
(
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
for x in [q, k, v]:
x.requires_grad = True
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
if fp8_mha:
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out = core_attn(
q,
......@@ -236,8 +290,30 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
# run core_attn with CP
############ run with CP ############
logging.info(f"[Rank {rank}] Run with context parallelism")
# set up CP communication group and reset grads/FP8 state
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if config.softmax_type != "vanilla":
core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# set up inputs
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
......@@ -267,8 +343,6 @@ def run_dpa_with_cp(
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
......@@ -276,19 +350,8 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# run attention
with fp8_context:
out_ = core_attn(
q_,
......@@ -306,18 +369,23 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
for x in [out_, q_.grad, k_.grad, v_.grad]:
# get outputs
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
if x is not None:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
......@@ -373,56 +441,70 @@ def run_dpa_with_cp(
).item()
== 0
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if qkv_format == "bshd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[:, 0], b[:, 0])
_error(a[:, 1], b[:, 1])
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[0], b[0])
_error(a[1], b[1])
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
# destroy distributed process group
dist.destroy_process_group()
......
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
import logging
import math
import os
import sys
import pathlib
......@@ -50,27 +49,35 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
compare_and_assert,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
# Only run FP8 tests on H100
# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Reset RNG seed and states
seed = 1234
# Reset RNG states
reset_rng_states()
# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
......@@ -86,12 +93,6 @@ model_configs_base = {
}
param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -125,12 +126,12 @@ def test_dot_product_attention(
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
# Get backends
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
......@@ -141,7 +142,6 @@ def test_dot_product_attention(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
......@@ -227,6 +227,7 @@ def test_dot_product_attention(
is_training,
)
# Compare results
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
......@@ -259,23 +260,102 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
"softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
"softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
"softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
"softmax_2_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
),
"softmax_2_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
),
"softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
"softmax_3_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
),
"softmax_3_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
),
"softmax_4_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
),
"softmax_4_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="off-by-one",
),
"softmax_4_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="learnable",
),
"softmax_5_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
),
"softmax_5_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="off-by-one",
),
"softmax_5_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="learnable",
),
}
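# A hedged reading aid, not part of this diff: the three softmax_type values in the
# configs above are assumed to follow the usual attention-sink formulations
#   vanilla:     exp(x_i) / sum_j exp(x_j)
#   off-by-one:  exp(x_i) / (1 + sum_j exp(x_j))            # fixed sink logit of 0
#   learnable:   exp(x_i) / (exp(offset) + sum_j exp(x_j))  # offset is a trained parameter
# The cuDNN kernel's exact definition may differ; this toy helper only illustrates the idea.
def _toy_sink_softmax(x, offset=None):
    """Row-wise sink softmax over the last dim; offset=None means a fixed sink logit of 0."""
    sink = torch.zeros_like(x[..., :1]) if offset is None else offset.expand_as(x[..., :1])
    probs = torch.cat([sink, x], dim=-1).softmax(dim=-1)
    return probs[..., 1:]  # drop the sink column; each row now sums to <= 1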
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(
dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
)
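# Hedged usage sketch (illustrative, not part of this diff): exercising the new
# softmax_type argument and the softmax_offset gradient directly, outside the test
# harness. Assumes DotProductAttention is imported at the top of this file and uses
# placeholder bshd shapes; not invoked by the test suite.
def _example_dpa_with_sink_softmax():
    attn = DotProductAttention(
        64,  # num_attention_heads
        64,  # kv_channels / head dim
        num_gqa_groups=8,
        qkv_format="bshd",
        attn_mask_type="causal",
        softmax_type="learnable",  # or "vanilla" / "off-by-one"
    ).cuda()
    attn.softmax_offset.requires_grad = True
    q = torch.randn(2, 2048, 64, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn(2, 2048, 8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    v = torch.randn(2, 2048, 8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    out = attn(q, k, v)
    out.backward(torch.randn_like(out))
    return attn.softmax_offset.grad  # gradient of the learnable sink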
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
# test: ModelConfig(b, sq, hq, dqk)
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
"mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
), # cross, 1
),
"mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
),
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
}
......@@ -289,7 +369,7 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
......@@ -344,18 +424,16 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
"bias_1_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
), # skipped
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
"bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_2_1": ModelConfig(
2,
128,
......@@ -364,10 +442,10 @@ model_configs_bias = {
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_2_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_2_3": ModelConfig(
2,
2048,
......@@ -376,13 +454,11 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
),
"bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
),
"bias_3_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
......@@ -400,14 +476,14 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
), # skipped
),
"bias_4_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_4_1": ModelConfig(
2,
128,
......@@ -416,10 +492,10 @@ model_configs_bias = {
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_4_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_4_3": ModelConfig(
2,
2048,
......@@ -428,10 +504,10 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
), # skipped
),
"bias_4_5": ModelConfig(
2,
2048,
......@@ -440,7 +516,7 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
), # skipped
),
}
......@@ -454,7 +530,7 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
......@@ -492,7 +568,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
......@@ -532,7 +608,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
# test: ModelConfig(b, sq, hq, dqk)
"alibi_1_0": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
......@@ -586,7 +662,7 @@ qkv_layouts = [
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
......@@ -634,7 +710,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
......@@ -726,7 +802,6 @@ def _run_dot_product_attention(
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -989,9 +1064,12 @@ def _run_dot_product_attention(
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
if is_training and config.softmax_type != "vanilla":
block.softmax_offset.requires_grad = True
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
......@@ -1026,12 +1104,14 @@ def _run_dot_product_attention(
)
if is_training:
out.backward(d_out)
d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad)
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None)
return out, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1060,18 +1140,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
else:
return out_orig, (None, None, None)
return out_orig, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad)
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None)
return out, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
......@@ -1436,6 +1516,7 @@ def _run_transformer_layer(
model_configs_fp8_extra_state = {
# test: ModelConfig(b, sq, hq, dqk)
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
......@@ -1445,7 +1526,8 @@ model_configs_fp8_extra_state = {
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
def test_dpa_fp8_extra_state(model, dtype):
"""Test DotProductAttention module in FP8 with checkpointing"""
config = model_configs_fp8_extra_state[model]
# Test backend availability
is_training = True
......@@ -1459,9 +1541,9 @@ def test_sanity_attention_extra_state(model, dtype):
if not fused_attn_supported and not flash_attn_supported:
pytest.skip("No attention backend available.")
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
......@@ -1483,7 +1565,8 @@ def test_sanity_attention_extra_state(model, dtype):
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
"""Run DotProductAttention module in FP8 with checkpointing"""
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
......@@ -1580,7 +1663,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
......@@ -1600,33 +1683,6 @@ qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
......@@ -1638,6 +1694,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
"""Test MultiHeadAttention module in FP8"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
......@@ -1691,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
_error(
compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
......@@ -1699,8 +1756,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
_error(
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -1708,12 +1766,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
_error(
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
......@@ -1721,10 +1780,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
"""Run MultiHeadAttention module in FP8"""
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -1851,6 +1912,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
"""Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]
# TODO(cyang): think of another way to verify dropout results
......@@ -1920,7 +1982,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
_error(
compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
......@@ -1928,6 +1990,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
if config.dropout_p != 0.0:
# test cuDNN FP8 dropout
......@@ -1935,7 +1998,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
fused_attn_fwd_fp8 == 1
), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
else:
_error(
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -1943,11 +2006,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error(
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
......@@ -1955,11 +2019,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
"""Run DotProductAttention module in FP8"""
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -2092,7 +2157,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 2048, 1, 128),
......@@ -2147,7 +2212,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.13
_error(
compare_and_assert(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -2155,8 +2220,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
True,
)
_error(
compare_and_assert(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
......@@ -2164,6 +2230,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
True,
)
......
......@@ -6,6 +6,7 @@ import os
import subprocess
import sys
import pathlib
import logging
import pytest
import torch
......@@ -19,13 +20,15 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
pytest_logging_level = logging.getLevelName(logging.root.level)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
......@@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
......@@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
)
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
available_backends, *_ = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
)
flash_attn_supported, *_ = available_backends
if not flash_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......@@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
log_level=pytest_logging_level,
),
check=True,
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
......@@ -135,6 +150,15 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
"cp_4_1": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
), # GQA
"cp_4_2": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
), # GQA
}
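# Hedged sketch (illustrative, not part of this diff): how one of the new sink-softmax
# CP configs above could be screened for fused-attention support before launching the
# distributed worker. Values mirror cp_4_2; the helpers come from utils.py.
def _example_check_sink_softmax_cp_backend():
    config = ModelConfig(
        2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    )
    config.context_parallel = True
    config.cp_comm_type = "a2a"  # non-vanilla softmax currently requires a2a (see skips below)
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.bfloat16,
        qkv_layout="bshd_bshd_bshd",
    )
    _, fused_attn_supported, _ = available_backends
    return fused_attn_supported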
......@@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
......@@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
if dtype == "fp8" and config.softmax_type != "vanilla":
pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
......@@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha,
log_level=pytest_logging_level,
),
check=True,
)
......@@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
......
......@@ -20,6 +20,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
......@@ -137,6 +138,31 @@ def reset_rng_states() -> None:
torch.cuda.set_rng_state(cuda_rng_state)
def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
"""Compare tensors a and b within atol/rtol; for FP8 results (is_fp8=True), fall back to a relative RMSE check."""
if not is_fp8:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
return
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = torch.sqrt((a - b).square().mean()).item()
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
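# Hedged usage sketch (illustrative, not part of this change): compare an FP8 result
# against an F16 reference with the atol/rtol check and the RMSE fallback enabled.
# The tensors and tolerances below are placeholders.
def _example_compare_and_assert():
    a = torch.randn(2, 2048, 64 * 64)    # e.g. fused_attn_fwd_fp8 (dequantized)
    b = a + 1e-3 * torch.randn_like(a)   # e.g. fused_attn_fwd_f16
    compare_and_assert(
        a, b, "fused_attn_fwd_fp8", "fused_attn_fwd_f16",
        atol=5e-1, rtol=5e-1, rmse_tol=0.1, is_fp8=True,
    )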
class ModelConfig:
def __init__(
self,
......@@ -147,12 +173,15 @@ class ModelConfig:
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
softmax_type: str = "vanilla",
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -171,13 +200,16 @@ class ModelConfig:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.softmax_type = softmax_type
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -198,9 +230,7 @@ def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
......@@ -250,19 +280,21 @@ def get_available_attention_backends(
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
window_size=config.window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
context_parallel=config.context_parallel,
cp_comm_type=config.cp_comm_type,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
)
(
use_flash_attention,
......
......@@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// sm90: fwd d<=256, bwd d=128 only
// sm100: fwd d<=128, bwd d<=128
((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
(sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
(sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
......@@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset &&
!requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
......@@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset) {
!requires_64bit_ragged_offset &&
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
flag_m512 = true;
}
if (
......@@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// check 64-bit ragged offset support
(supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) &&
// softmax type
// pre-9.13.1: vanilla
// 9.13.1+: vanilla, off-by-one, learnable
(cudnn_runtime_version >= 91301 ||
(cudnn_runtime_version < 91301 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......@@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
// NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
const Tensor *input_QKV = convertNVTETensorCheck(QKV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
......@@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
Tensor *input_output_dP = convertNVTETensorCheck(dP);
Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_QKV->data.shape.size();
......@@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
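    // The forward pass packs its aux tensors in a fixed order (softmax stats, rng_state, then
    // bias for post-scale bias and softmax_offset for non-vanilla softmax), so the running
    // index i consumes exactly the slots that were produced.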
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked(
const Tensor *input_Q = convertNVTETensorCheck(Q);
const Tensor *input_KV = convertNVTETensorCheck(KV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor *output_dQ = convertNVTETensorCheck(dQ);
Tensor *output_dKV = convertNVTETensorCheck(dKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
......@@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ,
output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked(
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_K = convertNVTETensorCheck(K);
const Tensor *input_V = convertNVTETensorCheck(V);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
......@@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *output_dK = convertNVTETensorCheck(dK);
Tensor *output_dV = convertNVTETensorCheck(dV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_Q->data.shape.size();
......@@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
......
......@@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats,
void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_causal = true;
is_bottom_right = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (is_training && dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
......@@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
FADescriptor_v1 descriptor{b,
h,
......@@ -122,6 +124,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout,
bias_type,
mask_type,
softmax_type,
window_size_left,
window_size_right,
true,
......@@ -138,6 +141,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_k
......@@ -168,7 +172,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale, softmax_offset;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
......@@ -302,6 +306,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
if (is_softmax_offset) {
softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_options.set_sink_token(softmax_offset);
}
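    // The softmax offset is a single FP32 value per attention head ({1, h, 1, 1}); passing it via
    // set_sink_token() lets the fused kernel include it as an extra "sink" logit in the softmax
    // normalization (the learnable-offset scheme behind sink attention), without it producing an
    // attention output of its own.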
auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
std::vector<int64_t> o_stride(4);
......@@ -338,6 +351,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto softmax_offset_tuple =
is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
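    // Optional graph nodes are carried as nullptr placeholders so the cached tuple keeps a fixed
    // arity no matter which features (bias, softmax offset, padding, paged KV, dropout) are enabled.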
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
......@@ -358,17 +373,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple,
offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv,
page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
......@@ -473,6 +489,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
if (is_softmax_offset) {
variant_pack[softmax_offset] = devPtrSoftmaxOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
......@@ -483,14 +504,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -506,6 +527,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_causal = true;
is_bottom_right = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
......@@ -558,6 +580,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout,
bias_type,
mask_type,
softmax_type,
window_size_left,
window_size_right,
deterministic,
......@@ -579,6 +602,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // d_softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
......@@ -608,7 +633,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, softmax_offset, d_softmax_offset,
seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
......@@ -771,6 +797,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
if (is_softmax_offset) {
softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_backward_options.set_sink_token(softmax_offset);
d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("d_softmax_offset")
.set_dim({1, h, 1, 1})
.set_stride({h, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
sdpa_backward_options.set_dsink_token(d_softmax_offset);
}
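    // d_softmax_offset mirrors the {1, h, 1, 1} shape of the offset and is registered via
    // set_dsink_token() so the backward graph also emits the gradient of the per-head sink scalar.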
auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
......@@ -796,6 +837,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>> // dV
key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto softmax_offset_tuple = is_softmax_offset
? std::make_tuple(softmax_offset, d_softmax_offset)
: std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
......@@ -814,17 +858,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
softmax_offset_tuple, padding_tuple, offset_qo_tuple,
offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
......@@ -938,6 +982,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
if (is_softmax_offset) {
variant_pack[softmax_offset] = devPtrSoftmaxOffset;
variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
......@@ -949,8 +998,9 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -977,6 +1027,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -990,11 +1044,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
max_tokens = get_max_tokens(num_tokens);
}
size_t i = 0;
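  // Aux tensors follow a fixed order: softmax stats, rng_state, then bias and softmax_offset when
  // applicable. The index i tracks the next slot both when the pack is first described
  // (size == 0) and when an already-sized pack is filled in.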
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
......@@ -1002,41 +1055,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1050,11 +1101,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr,
nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1074,9 +1125,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1122,6 +1174,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
......@@ -1135,11 +1193,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK,
devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1161,12 +1219,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1192,6 +1250,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -1216,11 +1278,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
......@@ -1228,41 +1289,39 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1277,11 +1336,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1302,10 +1361,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1359,6 +1419,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1374,9 +1440,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......@@ -1401,12 +1468,13 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1425,6 +1493,10 @@ void fused_attn_arbitrary_seqlen_fwd(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1446,11 +1518,10 @@ void fused_attn_arbitrary_seqlen_fwd(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
......@@ -1458,41 +1529,39 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
......@@ -1507,11 +1576,11 @@ void fused_attn_arbitrary_seqlen_fwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1532,13 +1601,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
......@@ -1577,6 +1647,12 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1592,9 +1668,10 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......
......@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
......@@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
false,
......
......@@ -107,6 +107,7 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
......@@ -116,14 +117,15 @@ struct FADescriptor_v1 {
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type,
bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
rhs.bwd_tensor_type);
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
......
......@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
 * where alpha is a learnable parameter of shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
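As a concrete reading of the formulas above, here is a minimal PyTorch sketch (illustration only, not part of this header) of the three softmax variants; `scores` stands for the raw [b, h, sq, skv] attention scores and `alpha` for a hypothetical per-head sink parameter of shape [h]:

import torch

def softmax_with_sink(scores, softmax_type="vanilla", alpha=None):
    # scores: [b, h, sq, skv]; numerically naive (no max subtraction), for illustration only
    exp_scores = torch.exp(scores)
    denom = exp_scores.sum(dim=-1, keepdim=True)
    if softmax_type == "off_by_one":
        # NVTE_OFF_BY_ONE_SOFTMAX: an implicit sink adds 1 to the denominator
        denom = denom + 1.0
    elif softmax_type == "learnable":
        # NVTE_LEARNABLE_SOFTMAX: a learnable per-head sink adds exp(alpha[j])
        denom = denom + torch.exp(alpha).view(1, -1, 1, 1)
    return exp_scores / denom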
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
......@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
......@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
......@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
......@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -36,6 +36,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
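Since this macro is instantiated by the PyTorch extension (imported as transformer_engine_torch later in this diff), the new enum should be reachable from Python; a small sanity-check sketch, assuming that exposure:

import transformer_engine_torch as tex

softmax_type = tex.NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
assert softmax_type != tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
assert hasattr(tex.NVTE_Softmax_Type, "NVTE_LEARNABLE_SOFTMAX")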
......@@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right) {
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
return backend;
}
......@@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
......@@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must be equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
......@@ -262,10 +268,15 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -280,12 +291,12 @@ static void FusedAttnForwardImpl(
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
......@@ -293,12 +304,13 @@ static void FusedAttnForwardImpl(
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
......@@ -307,12 +319,13 @@ static void FusedAttnForwardImpl(
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// For cuDNN < 9.3.0, all possible seqlens need to be run to handle act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq;
}
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor =
......@@ -453,13 +469,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right,
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
......@@ -467,23 +484,23 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, deterministic, query_workspace_tensor.data(),
nullptr);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, deterministic,
query_workspace_tensor.data(), nullptr);
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl(
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
......@@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
......@@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
......@@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, deterministic, workspace_tensor.data(), stream);
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......
......@@ -13,6 +13,7 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
SplitAlongDim,
......@@ -142,6 +143,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -149,6 +151,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
......@@ -185,6 +188,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
......@@ -326,7 +330,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
# attention scores and attention mask [b, np, sq, sk]
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
......@@ -337,6 +355,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
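# Why this works: with the offset appended as an extra key column, the softmax
# denominator becomes exp(offset) + sum(exp(S)), so the first sk columns already
# carry the off-by-one / learnable-sink normalization; dropping the sink column
# simply discards the probability mass assigned to it.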
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
......@@ -917,6 +939,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
fused_attention_backend,
......@@ -925,6 +948,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
quantizers,
deterministic,
softmax_offset,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
......@@ -997,8 +1021,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
if is_output_fp8:
out_ret = out_fp8
......@@ -1059,8 +1085,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
out_save = out_ret
fp8_tensors = (None, None, None, None)
......@@ -1114,6 +1142,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
......@@ -1224,6 +1253,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
......@@ -1287,12 +1317,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
d_bias = None
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
d_softmax_offset = None
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
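# The gradients below are returned positionally, one entry per forward() argument:
# dq/dk/dv and d_bias take the slots of q/k/v and attn_bias, every non-tensor
# argument gets None, and d_softmax_offset fills the last slot because
# softmax_offset is the last forward() argument.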
return (
None,
None,
......@@ -1306,6 +1341,7 @@ class FusedAttnFunc(torch.autograd.Function):
dq,
dk,
dv,
d_bias,
None,
None,
None,
......@@ -1321,36 +1357,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
None,
None,
None,
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
rest[0],
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
d_softmax_offset,
)
......@@ -1390,6 +1397,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -1402,6 +1410,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1453,6 +1462,7 @@ class FusedAttention(torch.nn.Module):
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
......@@ -1603,6 +1613,8 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
)
else:
with self.attention_dropout_ctx():
......@@ -1626,6 +1638,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
self.softmax_type,
window_size,
None, # rng_gen
fused_attention_backend,
......@@ -1634,6 +1647,7 @@ class FusedAttention(torch.nn.Module):
fp8_meta,
quantizers,
self.deterministic,
softmax_offset,
)
# ...hd -> ...(hd)
......
......@@ -46,6 +46,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
_cu_seqlens_info_with_cp_cache = {}
_seq_chunk_ids_cache_for_reordering_before_attn = {}
_seq_chunk_ids_cache_for_reordering_after_attn = {}
_softmax_offset_chunk_ids_cache = {}
def flash_attn_p2p_communicate(
......@@ -318,6 +319,55 @@ def flash_attn_a2a_communicate(
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
def flash_attn_a2a_communicate_softmax_offset(
tensor: torch.Tensor,
h_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Split/AllGather communication for softmax offset."""
if tensor is None:
return None
global _softmax_offset_chunk_ids_cache
device = tensor.device
if (cp_size, device) not in _softmax_offset_chunk_ids_cache:
chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device)
_softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids
else:
chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)]
if before_attn:
# softmax_offset: split round-robin to CP ranks
# [1, h, 1, 1] -> [1, cp, h//cp, 1, 1]
shape = tensor.shape
tensor = tensor.view(
*shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :]
)
rank = get_distributed_rank(cp_group)
output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank])
output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :])
else:
# d_softmax_offset: all-gather from all ranks to all ranks
# [1, h//cp, 1, 1] -> [1, h, 1, 1]
inp = tensor.view(-1)
output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device)
with torch.cuda.stream(cp_stream):
torch.distributed.all_gather_into_tensor(
output,
inp,
group=cp_group,
async_op=False,
)
torch.cuda.current_stream().wait_stream(cp_stream)
output = output.view(
*tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :]
)
return output
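For intuition, a single-process sketch (hypothetical shapes, not part of this change) of the shape bookkeeping done above, assuming h_dim == 1 and a head count divisible by cp_size:

import torch

cp_size, h = 4, 8
offset = torch.arange(h, dtype=torch.float32).view(1, h, 1, 1)  # [1, h, 1, 1]
# before_attn=True: rank r keeps the r-th block of h // cp_size heads
per_rank = [
    offset.view(1, cp_size, h // cp_size, 1, 1)[:, r].reshape(1, -1, 1, 1)
    for r in range(cp_size)
]
assert per_rank[0].shape == (1, h // cp_size, 1, 1)
# before_attn=False: all-gathering the per-rank pieces restores the full [1, h, 1, 1]
assert torch.equal(torch.cat(per_rank, dim=1), offset)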
def _get_cu_seqlens_info_with_cp(
batch_size: int,
max_seqlen: int,
......@@ -1854,7 +1904,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2014,7 +2064,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2171,7 +2221,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2289,7 +2339,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -3122,7 +3172,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
if ctx.use_fused_attention:
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd(
ctx.max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
......@@ -3283,6 +3333,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
......@@ -3391,6 +3443,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)
if softmax_type != "vanilla":
softmax_offset = flash_attn_a2a_communicate_softmax_offset(
softmax_offset, 1, cp_size, cp_group, cp_stream, True
)
if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16, k_f16, v_f16 = q, k, v
......@@ -3430,6 +3486,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
window_size=window_size,
**fp8_meta_kwargs,
softmax_type=softmax_type,
softmax_offset=softmax_offset,
)
if fp8:
out = out._data
......@@ -3532,6 +3590,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
ctx.softmax_type = softmax_type
ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer
......@@ -3695,7 +3754,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dout_part, fake_dtype=dout_dtype, internal=True
)
dq, dk, dv, _ = fused_attn_bwd(
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
......@@ -3719,6 +3778,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
window_size=ctx.window_size,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
softmax_type=ctx.softmax_type,
)
if ctx.fp8:
dq = dq._data
......@@ -3763,6 +3823,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
d_bias = None
d_softmax_offset = None
if ctx.use_fused_attention:
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
d_softmax_offset = flash_attn_a2a_communicate_softmax_offset(
d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False
)
if ctx.fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(
dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
......@@ -3793,6 +3864,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_bias,
None,
None,
None,
......@@ -3803,6 +3875,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_softmax_offset,
)
......@@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp(
quantizers=None,
pad_between_seqs=False,
use_flash_attn_3=False,
softmax_type="vanilla",
softmax_offset=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
......@@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp(
else:
assert isinstance(
cp_group, dist_group_type
), f"Unsupported process group for CP communication type {cp_comm_type}!"
), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!"
assert qkv_format in [
"bshd",
"sbhd",
"thd",
], f"QKV format of {qkv_format} is not supported with context parallelism!"
], f"Context parallelism does not support {qkv_format=}!"
assert (
qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!"
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
"Context parallelism only supports attention bias with FusedAttention backend and"
" non-padding mask types!"
)
assert qkv_format != "thd" or (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_padded cannot be None with context parallelism + THD format!"
), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!"
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
......@@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "The context parallel running configs cannot support sliding window attetnion!"
], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "The context parallel running configs cannot support MLA!"
], "Context parallelism does not support MLA with {cp_comm_type=}!"
if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), "Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
args = [
is_training,
......@@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp(
args += [window_size, cp_group, cp_stream, use_flash_attn_3]
out = AttnFuncWithCPAndKVAllGather.apply(*args)
elif cp_comm_type == "a2a":
args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3]
args += [
window_size,
fp8,
fp8_meta,
cp_group,
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
]
out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
......
......@@ -11,6 +11,7 @@ import warnings
import logging
import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import get_cudnn_version
......@@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
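As an unfused reference, the three variants can be sketched as below (illustrative
only; the fused and unfused backends compute this internally, and the helper name
``reference_softmax`` is not part of the API)::

    import torch
    def reference_softmax(scores, softmax_type="vanilla", offset=None):
        # scores: [b, h, s_q, s_kv]; offset (alpha): [h], used only for 'learnable'
        if softmax_type == "vanilla":
            return torch.softmax(scores, dim=-1)
        if softmax_type == "off-by-one":
            offset = torch.zeros(scores.shape[1], device=scores.device)
        sink = offset.view(1, -1, 1, 1).exp()  # exp(alpha[j]); equals 1 for 'off-by-one'
        num = scores.exp()
        return num / (sink + num.sum(dim=-1, keepdim=True))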
Parallelism parameters
----------------------
......@@ -223,6 +235,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -307,6 +320,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
......@@ -328,6 +355,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -335,6 +363,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
)
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -634,6 +663,7 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer,
num_gemms=3,
allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
......@@ -922,6 +952,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
......@@ -957,11 +988,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1022,6 +1055,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
......@@ -1071,7 +1110,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -1101,6 +1139,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.fused_attention(
query_layer,
......@@ -1129,6 +1168,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
......@@ -1157,6 +1197,7 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.unfused_attention(
_alibi_cache,
......@@ -1173,5 +1214,6 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return None
......@@ -24,6 +24,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
......@@ -206,6 +207,8 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
......@@ -216,6 +219,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -237,11 +242,13 @@ class AttentionParams:
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
......@@ -308,11 +315,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -565,6 +574,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
......@@ -806,6 +860,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
......
......@@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
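A minimal construction sketch (sizes below are illustrative values, not defaults)::

    import torch
    import transformer_engine.pytorch as te
    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        softmax_type="learnable",  # adds a per-head learnable sink parameter
    ).cuda()
    x = torch.randn(128, 2, 1024, device="cuda")  # [s, b, hidden]
    y = mha(x)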
Parallelism parameters
----------------------
......@@ -245,6 +256,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -262,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
......@@ -416,6 +429,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
......
......@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
......@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
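# Usage note (sketch): the string-valued ``softmax_type`` argument of fused_attn_fwd
# and fused_attn_bwd is resolved to the cuDNN enum through this mapping, e.g.
# SoftmaxType["off-by-one"] -> NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX.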
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
......@@ -131,8 +138,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -197,6 +206,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -205,6 +216,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
Returns
----------
......@@ -286,6 +300,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
......@@ -300,6 +315,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
......@@ -333,6 +349,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -398,6 +415,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -417,6 +436,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
d = q.size(-1)
......@@ -454,6 +476,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
......