Unverified Commit 5e4e0b2c authored by Charlene Yang, committed by GitHub

[PyTorch] Add sink attention support from cuDNN (#2148)



* first draft; debug plan failure
* debug uid error
* tweak params
* add grad in output
* clean up prints
* fix prints in test
* Apply 1 suggestion(s) to 1 file(s)
* address review comments
* fix unfused grad; add softmax_type; add sink to bwd
* Apply 1 suggestion(s) to 1 file(s)
* fix padding mask; add swa tests; remove requires_grad for off-by-one
* update FE
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* Apply 1 suggestion(s) to 1 file(s)
* fix indent
* fix non-determinism and shapes
* clean up prints
* add GQA
* add CP A2A; dq/dk mismatches
* fix CP A2A; need cleaner solution
* fix CP A2A; pending cudnn kernel change
* minor fixes
* fix world size in unit test; avoid thd format
* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
* fix thd logic
* fix fp8 context
* tweak CP logging
* allow no_mask/padding for SWA(left,0)
* Revert "allow no_mask/padding for SWA(left,0)"
  This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
* add softmax_type to Jax
* add cuDNN version control
* prettify tests
* skip 9.13 for MLA, non 192/128
* rename compare_with_error
* small cleanups and improvements
* fix minor CI failures
* force sink/dsink to be float32
* switch FE to GH FE
* return to GH TE main FE commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update FE to 1.14.1
* clean up before CI
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix lint
* bump up cudnn version
* add backend selection guard for unit tests
* add docstring for softmax type enums in C

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 57b4d7bc
Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 (old)
Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 (new)
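For orientation before the diff: a minimal, hedged sketch of how the sink-attention option added by this PR is exercised. It mirrors the test usage shown below (DotProductAttention with the new softmax_type argument and its softmax_offset parameter); the sizes and the choice of softmax_type here are illustrative only.

import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes; the fused-attention CP tests below use
# ModelConfig(2, 4096, 64, 64, num_gqa_groups=8) for the sink-attention cases.
num_heads, head_dim, num_gqa_groups = 64, 64, 8

# softmax_type selects the softmax variant: "vanilla" (default),
# "off-by-one", or "learnable" (learnable per-head sink).
core_attn = DotProductAttention(
    num_heads,
    head_dim,
    num_gqa_groups=num_gqa_groups,
    attn_mask_type="causal",
    softmax_type="learnable",
).cuda()

# For non-vanilla softmax types the sink parameter is exposed as
# core_attn.softmax_offset; the tests turn on its gradient explicitly.
core_attn.softmax_offset.requires_grad = True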
...@@ -17,88 +17,18 @@ from test_attention_with_cp import model_configs_flash_attn, model_configs_fused
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def generate_input_shapes(
qkv_format: str,
config: ModelConfig,
world_size: int,
kernel_backend: str,
):
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd": if qkv_format == "bshd":
q_input_shape = ( q_input_shape = (
config.batch_size, config.batch_size,
...@@ -191,34 +121,158 @@ def run_dpa_with_cp( ...@@ -191,34 +121,158 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded cu_seqlens_kv_padded = cu_seqlens_q_padded
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format=} is not supported!"
return (
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
)
def get_tols(config, dtype):
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
atol = 2.5e-2
rtol = 2.5e-2
else:
atol = 3.5e-2
rtol = 3.5e-2
rmse_tol = 0.01
elif dtype == "fp16":
atol = 5e-3
rtol = 5e-3
rmse_tol = 0.01
elif dtype == "fp8":
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
else:
assert False, f"{dtype=} is not supported!"
return atol, rtol, rmse_tol
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# set up environment variables and config
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type=} is not supported!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# set up communication group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
" = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate attention module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
# generate attention inputs
(
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
for x in [q, k, v]:
x.requires_grad = True
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
if fp8_mha:
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8": if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else: else:
fp8_context = nullcontext() fp8_context = nullcontext()
with fp8_context: with fp8_context:
out = core_attn( out = core_attn(
q, q,
...@@ -236,8 +290,30 @@ def run_dpa_with_cp( ...@@ -236,8 +290,30 @@ def run_dpa_with_cp(
out.backward(dout_fp8) out.backward(dout_fp8)
else: else:
out.backward(dout) out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
############ run with CP ############
logging.info(f"[Rank {rank}] Run with context parallelism")
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if config.softmax_type != "vanilla":
core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# set up inputs
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
...@@ -267,8 +343,6 @@ def run_dpa_with_cp(
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
...@@ -276,19 +350,8 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# run attention
with fp8_context:
out_ = core_attn(
q_,
...@@ -306,18 +369,23 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
# get outputs
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
if x is not None:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd": if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [ dq, dk, dv, out = [
x.view( x.view(
...@@ -373,56 +441,70 @@ def run_dpa_with_cp( ...@@ -373,56 +441,70 @@ def run_dpa_with_cp(
).item() ).item()
== 0 == 0
) )
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
if qkv_format == "bshd": atol, rtol, rmse_tol = get_tols(config, dtype)
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
_error(a[:, 0], b[:, 0]) tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
_error(a[:, 1], b[:, 1]) names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
elif qkv_format == "sbhd": names_cp = [x + "_cp" for x in names]
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): names_no_cp = [x + "_no_cp" for x in names]
_error(a[0], b[0]) is_fp8 = dtype == "fp8"
_error(a[1], b[1]) for i, t in enumerate(tensors_no_cp):
elif qkv_format == "thd": if t is not None:
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): if "softmax_offset" not in names[i]:
_error(a, b) if qkv_format == "bshd":
else: compare_and_assert(
assert False, f"{qkv_format} is an unsupported qkv_format!" t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
else:
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
# destroy distribution group
dist.destroy_process_group()
...
...@@ -6,6 +6,7 @@ import os
import subprocess
import sys
import pathlib
import logging
import pytest
import torch
...@@ -19,13 +20,15 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
pytest_logging_level = logging.getLevelName(logging.root.level)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model_configs_flash_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
...@@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!") pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd": if cp_comm_type == "all_gather" and qkv_format == "thd":
...@@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
) )
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!") pytest.skip("MLA CP currently only support KV P2P!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
available_backends, *_ = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
)
flash_attn_supported, *_ = available_backends
if not flash_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
...@@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
log_level=pytest_logging_level,
),
check=True,
)
model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
...@@ -135,6 +150,15 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
"cp_4_1": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
), # GQA
"cp_4_2": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
), # GQA
}
...@@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
...@@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
if dtype == "fp8" and config.softmax_type != "vanilla":
pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtypes[dtype], qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3), qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
) )
_, fused_attn_supported, _ = available_backends _, fused_attn_supported, _ = available_backends
if not fused_attn_supported: if not fused_attn_supported:
...@@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
kernel_backend="FusedAttention", kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type, cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha, fp8_mha=fp8_mha,
log_level=pytest_logging_level,
), ),
check=True, check=True,
) )
...@@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
...
...@@ -20,6 +20,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
...@@ -137,6 +138,31 @@ def reset_rng_states() -> None:
torch.cuda.set_rng_state(cuda_rng_state)
def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
if not is_fp8:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
return
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = torch.sqrt((a - b).square().mean()).item()
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
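As a usage reference, a small standalone sketch of how the CP tests above call this helper (the tensor names and shape here are hypothetical; the tolerances mirror get_tols() for bf16 MHA, and with is_fp8=False the helper reduces to a plain torch.testing.assert_close):

import torch

# Hypothetical comparison of a CP output against its no-CP reference.
out_no_cp = torch.randn(2, 4096, 12 * 128, dtype=torch.bfloat16)
out_cp = out_no_cp.clone()
compare_and_assert(out_no_cp, out_cp, "out_no_cp", "out_cp", 2.5e-2, 2.5e-2, 0.01, False)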
class ModelConfig:
def __init__(
self,
...@@ -147,12 +173,15 @@ class ModelConfig:
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
softmax_type: str = "vanilla",
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
...@@ -171,13 +200,16 @@ class ModelConfig:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.softmax_type = softmax_type
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
...@@ -198,9 +230,7 @@ def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
...@@ -250,19 +280,21 @@ def get_available_attention_backends(
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=config.context_parallel,
cp_comm_type=config.cp_comm_type,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
)
(
use_flash_attention,
...
...@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
...@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
...@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif  // CUDNN_VERSION >= 8900
}  // namespace transformer_engine
...
...@@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
...@@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
false,
...
...@@ -107,6 +107,7 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
...@@ -116,14 +117,15 @@ struct FADescriptor_v1 {
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type,
bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
...
...@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
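To make the three variants concrete, a small hedged PyTorch sketch of the formulas in the docstring above (a reference-only illustration; the real computation happens inside the cuDNN fused-attention kernels, and no numerical-stability tricks such as max-subtraction are applied here):

import torch
from typing import Optional

def softmax_variants(scores: torch.Tensor, alpha: Optional[torch.Tensor], off_by_one: bool):
    # scores: [b, h, s_q, s_kv] attention scores S = Q*K^T (already scaled).
    # NVTE_VANILLA_SOFTMAX:    exp(S_i) / sum_j exp(S_j)
    # NVTE_OFF_BY_ONE_SOFTMAX: exp(S_i) / (1 + sum_j exp(S_j))
    # NVTE_LEARNABLE_SOFTMAX:  exp(S_i) / (exp(alpha[h]) + sum_j exp(S_j)), alpha of shape [h]
    e = torch.exp(scores)
    denom = e.sum(dim=-1, keepdim=True)
    if off_by_one:
        denom = denom + 1.0
    elif alpha is not None:
        denom = denom + torch.exp(alpha).view(1, -1, 1, 1)
    return e / denom

scores = torch.randn(1, 2, 4, 4)
alpha = torch.zeros(2)  # learnable per-head sink parameter
p_off_by_one = softmax_variants(scores, None, off_by_one=True)
p_learnable = softmax_variants(scores, alpha, off_by_one=False)
# With alpha = 0, the learnable variant reduces to the off-by-one variant.
assert torch.allclose(p_off_by_one, p_learnable)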
/*! \enum NVTE_Fused_Attn_Backend /*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends * \brief Fused attention backends
*/ */
...@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); ...@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type. * \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type. * \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability. * \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q. * \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V. * \param[in] num_gqa_groups The number of heads in K, V.
...@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); ...@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/ */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
* *
...@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* *
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
...@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type. * \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type. * \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, void nvte_fused_attn_fwd_qkvpacked(
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor rng_state, size_t max_seqlen, bool is_training, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
...
@@ -36,6 +36,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
...
@@ -13,6 +13,7 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
SplitAlongDim,
@@ -142,6 +143,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
@@ -149,6 +151,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
@@ -185,6 +188,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
@@ -326,7 +330,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
@@ -337,6 +355,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
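The unfused path above emulates the sink by appending one extra logit column before the softmax and dropping that column afterwards. A minimal standalone sketch of the same idea for the 'off-by-one' case (illustrative only; shapes and tensors are made up):

# Off-by-one softmax via an appended zero-logit "sink" column.
import torch

scores = torch.randn(2, 4, 3, 5)          # [b, np, sq, sk]
sink = torch.zeros(2, 4, 3, 1)            # exp(0) = 1 contributes the "+1" in the denominator
probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]
reference = torch.exp(scores) / (1 + torch.exp(scores).sum(dim=-1, keepdim=True))
assert torch.allclose(probs, reference, atol=1e-6)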
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
@@ -917,6 +939,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
fused_attention_backend,
@@ -925,6 +948,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
quantizers,
deterministic,
softmax_offset,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
@@ -997,8 +1021,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
if is_output_fp8:
out_ret = out_fp8
@@ -1059,8 +1085,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
out_save = out_ret
fp8_tensors = (None, None, None, None)
@@ -1114,6 +1142,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
@@ -1224,6 +1253,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
@@ -1287,42 +1317,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
d_bias = None
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
d_softmax_offset = None
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
return (
None,
None,
@@ -1336,7 +1341,8 @@ class FusedAttnFunc(torch.autograd.Function):
dq,
dk,
dv,
d_bias,
None,
None,
None,
None,
@@ -1351,6 +1357,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
d_softmax_offset,
)
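The reshuffled return tuple follows the usual torch.autograd.Function contract: backward() returns one entry per forward() argument, with None for inputs that need no gradient, so d_softmax_offset sits in the slot that lines up with the new softmax_offset argument. A tiny illustrative example of that convention (all names made up):

# Sketch of the autograd.Function gradient-slot convention used above.
import torch

class ScaleByOffset(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, offset, some_flag):
        ctx.save_for_backward(x, offset)
        return x * offset

    @staticmethod
    def backward(ctx, grad_out):
        x, offset = ctx.saved_tensors
        # One slot per forward() argument: grads for x and offset, None for the flag.
        return grad_out * offset, (grad_out * x).sum(), None

x = torch.randn(3, requires_grad=True)
offset = torch.tensor(2.0, requires_grad=True)
ScaleByOffset.apply(x, offset, True).sum().backward()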
@@ -1390,6 +1397,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
@@ -1402,6 +1410,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
"""
@@ -1453,6 +1462,7 @@ class FusedAttention(torch.nn.Module):
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
@@ -1603,6 +1613,8 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
)
else:
with self.attention_dropout_ctx():
@@ -1626,6 +1638,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
self.softmax_type,
window_size,
None,  # rng_gen
fused_attention_backend,
@@ -1634,6 +1647,7 @@ class FusedAttention(torch.nn.Module):
fp8_meta,
quantizers,
self.deterministic,
softmax_offset,
)
# ...hd -> ...(hd)
...
@@ -46,6 +46,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
_cu_seqlens_info_with_cp_cache = {}
_seq_chunk_ids_cache_for_reordering_before_attn = {}
_seq_chunk_ids_cache_for_reordering_after_attn = {}
_softmax_offset_chunk_ids_cache = {}
def flash_attn_p2p_communicate(
@@ -318,6 +319,55 @@ def flash_attn_a2a_communicate(
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
def flash_attn_a2a_communicate_softmax_offset(
tensor: torch.Tensor,
h_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Split/AllGather communication for softmax offset."""
if tensor is None:
return None
global _softmax_offset_chunk_ids_cache
device = tensor.device
if (cp_size, device) not in _softmax_offset_chunk_ids_cache:
chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device)
_softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids
else:
chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)]
if before_attn:
# softmax_offset: split round-robin to CP ranks
# [1, h, 1, 1] -> [1, cp, h//cp, 1, 1]
shape = tensor.shape
tensor = tensor.view(
*shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :]
)
rank = get_distributed_rank(cp_group)
output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank])
output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :])
else:
# d_softmax_offset: all-gather from all ranks to all ranks
# [1, h//cp, 1, 1] -> [1, h, 1, 1]
inp = tensor.view(-1)
output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device)
with torch.cuda.stream(cp_stream):
torch.distributed.all_gather_into_tensor(
output,
inp,
group=cp_group,
async_op=False,
)
torch.cuda.current_stream().wait_stream(cp_stream)
output = output.view(
*tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :]
)
return output
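Before attention, the [1, h, 1, 1] offset is split so that each context-parallel rank keeps only its own heads; after the backward pass the per-rank gradients are all-gathered back to the full [1, h, 1, 1] shape. A standalone sketch of the head split for a single rank (no process group involved; numbers are illustrative):

# Head split performed before attention, shown without any distributed calls.
import torch

h, cp_size, rank = 8, 4, 2
offset = torch.arange(h, dtype=torch.float32).view(1, h, 1, 1)   # [1, h, 1, 1]
chunks = offset.view(1, cp_size, h // cp_size, 1, 1)             # [1, cp, h//cp, 1, 1]
local = chunks[:, rank].reshape(1, h // cp_size, 1, 1)           # this rank's heads only
print(local.flatten())  # tensor([4., 5.])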
def _get_cu_seqlens_info_with_cp(
batch_size: int,
max_seqlen: int,
@@ -1854,7 +1904,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2014,7 +2064,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2171,7 +2221,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
@@ -2289,7 +2339,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
@@ -3122,7 +3172,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
if ctx.use_fused_attention:
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd(
ctx.max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
@@ -3283,6 +3333,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
@@ -3391,6 +3443,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)
if softmax_type != "vanilla":
softmax_offset = flash_attn_a2a_communicate_softmax_offset(
softmax_offset, 1, cp_size, cp_group, cp_stream, True
)
if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16, k_f16, v_f16 = q, k, v
@@ -3430,6 +3486,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
window_size=window_size,
**fp8_meta_kwargs,
softmax_type=softmax_type,
softmax_offset=softmax_offset,
)
if fp8:
out = out._data
@@ -3532,6 +3590,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
ctx.softmax_type = softmax_type
ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer
@@ -3695,7 +3754,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dout_part, fake_dtype=dout_dtype, internal=True
)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
@@ -3719,6 +3778,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
window_size=ctx.window_size,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
softmax_type=ctx.softmax_type,
)
if ctx.fp8:
dq = dq._data
@@ -3763,6 +3823,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
d_bias = None
d_softmax_offset = None
if ctx.use_fused_attention:
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
d_softmax_offset = flash_attn_a2a_communicate_softmax_offset(
d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False
)
if ctx.fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(
dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
@@ -3793,6 +3864,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_bias,
None,
None,
None,
@@ -3803,6 +3875,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_softmax_offset,
)
@@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp(
quantizers=None,
pad_between_seqs=False,
use_flash_attn_3=False,
softmax_type="vanilla",
softmax_offset=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
@@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp(
else:
assert isinstance(
cp_group, dist_group_type
), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!"
assert qkv_format in [
"bshd",
"sbhd",
"thd",
], f"Context parallelism does not support {qkv_format=}!"
assert (
qkv_format != "sbhd" or use_fused_attention
), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!"
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"Context parallelism only supports attention bias with FusedAttention backend and"
" non-padding mask types!"
)
assert qkv_format != "thd" or (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_padded cannot be None for context parallelism and qkv_format = 'thd'!"
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
@@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], f"Context parallelism does not support MLA with {cp_comm_type=}!"
if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), f"Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
args = [
is_training,
@@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp(
args += [window_size, cp_group, cp_stream, use_flash_attn_3]
out = AttnFuncWithCPAndKVAllGather.apply(*args)
elif cp_comm_type == "a2a":
args += [
window_size,
fp8,
fp8_meta,
cp_group,
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
]
out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
...
@@ -11,6 +11,7 @@ import warnings
import logging
import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import get_cudnn_version
@@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
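A small sketch of the three variants on a single row of scores, matching the formulas above (illustrative only; alpha stands in for one head's learnable offset):

# Reference computation of the three softmax types for one row of scores.
import torch

s = torch.randn(6)                 # one row of attention scores
alpha = torch.tensor(0.3)          # learnable offset for one head ('learnable' only)
vanilla = torch.exp(s) / torch.exp(s).sum()
off_by_one = torch.exp(s) / (1 + torch.exp(s).sum())
learnable = torch.exp(s) / (torch.exp(alpha) + torch.exp(s).sum())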
Parallelism parameters
----------------------
@@ -223,6 +235,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
@@ -307,6 +320,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
@@ -328,6 +355,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
)
self.unfused_attention = UnfusedDotProductAttention(
@@ -335,6 +363,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
)
def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
@@ -634,6 +663,7 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer,
num_gemms=3,
allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
@@ -922,6 +952,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
@@ -957,11 +988,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
@@ -1022,6 +1055,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
@@ -1071,7 +1110,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
@@ -1101,6 +1139,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.fused_attention(
query_layer,
@@ -1129,6 +1168,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
@@ -1157,6 +1197,7 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.unfused_attention(
_alibi_cache,
@@ -1173,5 +1214,6 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return None
@@ -24,6 +24,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
@@ -206,6 +207,8 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
@@ -216,6 +219,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
""" """
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
...@@ -237,11 +242,13 @@ class AttentionParams: ...@@ -237,11 +242,13 @@ class AttentionParams:
pad_between_seqs: bool = False pad_between_seqs: bool = False
attention_dropout: float = 0.0 attention_dropout: float = 0.0
context_parallel: bool = False context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
@@ -308,11 +315,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
@@ -565,6 +574,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
@@ -806,6 +860,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
...
@@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
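For reference, a minimal sketch of the three variants computed directly on a raw score tensor (illustrative only; the fused and flash backends never materialize these probabilities explicitly):

```python
import torch

b, h, s_q, s_kv = 2, 8, 4, 4
S = torch.randn(b, h, s_q, s_kv)   # raw attention scores Q*K^T
alpha = torch.zeros(h)             # learnable sink logits, one per head

exp_S = torch.exp(S)
denom = exp_S.sum(dim=-1, keepdim=True)

vanilla = exp_S / denom                                           # rows sum to 1
off_by_one = exp_S / (1.0 + denom)                                # zero sink: rows sum to < 1
learnable = exp_S / (torch.exp(alpha).view(1, h, 1, 1) + denom)   # learnable sink per head
```

Note that 'off-by-one' is the special case of 'learnable' with alpha fixed at zero.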
Parallelism parameters
----------------------
...@@ -245,6 +256,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
...@@ -262,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
...@@ -416,6 +429,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
...
...@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
...@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
...@@ -131,8 +138,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
...@@ -197,6 +206,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -205,6 +216,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
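A minimal sketch of constructing such an offset tensor, assuming h_q query heads (names here are illustrative, not part of this API); for the 'learnable' softmax type the caller would typically register it as a trainable parameter, while 'off-by-one' uses a fixed zero offset:

```python
import torch

h_q = 16  # number of query heads (example value)
# per-head sink logits, broadcastable against scores of shape [b, h_q, s_q, s_kv]
softmax_offset = torch.nn.Parameter(torch.zeros(1, h_q, 1, 1, device="cuda"))
```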
Returns
----------
...@@ -286,6 +300,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
...@@ -300,6 +315,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
...@@ -333,6 +349,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -398,6 +415,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -417,6 +436,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
""" """
if attn_scale is None: if attn_scale is None:
d = q.size(-1) d = q.size(-1)
...@@ -454,6 +476,7 @@ def fused_attn_bwd( ...@@ -454,6 +476,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type], AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
...