Unverified Commit 32db3928 authored by cyanguwa, committed by GitHub

Integrate cuDNN frontend v1 to fused attention (#497)



* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/paddle for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/pytorch lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify stride generation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix and/or logic in get_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix flag_max512 and test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove v.contiguous() since get_qkv_layout covers it
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip fp8 tests for sm89
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert mask type to comma-separated list
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last two commits
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* integrate v1/pre-release-5
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup prerelease5 integration and fix FA2.1 commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout to 0 if not training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* testing bias/alibi and padding+causal; add alibi to unfused DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set flag_arb to false when non-determinism is not allowed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* followup on prev commit; remove redundant python env var setting
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor tweaks for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prepare for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix determinism logic for fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add bias to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix gpt_checkpointing/dpa_accuracy problem
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix some seg fault issues
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add failure notes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove use of non-deter var for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for lint and CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix workspace size in bwd and uncomment bias test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_alibi and remove check_support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update tests status
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove workspace_opt from FADescriptor_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable arbitrary backend + post scale bias in Jax; waiting on PR 525
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up bhsd order
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove support for padding_causal + cross for max512
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change alibi bias to float32 for bias_1_4/5 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further clean up tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd fwd output shape for FlashAttention and add backend info for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix definition of workspace limit when dbias is present
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further tweak DP_WORKSPACE_LIMIT definition
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disallow alibi+no_mask for sdpa flash and update alibi tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable dbias for non-hopper archs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix layernorm lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused arg for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove build dir in setup.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change selection logic to prefer fused attn on sm90
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix distributed jax test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix h and s order in header
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to cudnn fe v1 branch
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove manual setting of workopt path due to dbias after v1 update
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add post_scale_bias and alibi to sdpa flash support matrix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix support matrix in header files
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move headers back to .cu and change seed/offset to int64
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update Megatron commit in L1 test and remove all prints in fused attn test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix L1 Megatron test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 arg in L1 Megatron script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* print only when debug flag is on
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove checkpoint loading to avoid loading other tests' results
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
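For reference, the central API change threaded through the diffs below is that the fused-attention backend query now takes the number of query heads and the number of GQA groups in addition to the sequence lengths and head dimension. The following is a minimal standalone sketch of that call pattern, not part of the commit itself; the import locations for the mapping dicts (TE_DType, QKVLayout, AttnBiasType, AttnMaskType, FusedAttnBackend) are assumed from the updated PyTorch tests.

# Illustrative sketch only (not part of this commit). Import locations are
# assumed from the updated PyTorch fused-attention tests.
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    QKVLayout,
)

def fused_attn_backend_available(
    dtype: torch.dtype,
    num_heads: int,
    num_gqa_groups: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    head_dim: int,
    qkv_layout: str = "sb3hd",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "no_mask",
    dropout_p: float = 0.0,
) -> bool:
    """Return True if any cuDNN fused-attention backend supports this config.

    Mirrors the signature change in this PR: num_heads and num_gqa_groups are
    now passed to the backend query alongside the sequence lengths and head dim.
    """
    backend = tex.get_fused_attn_backend(
        TE_DType[dtype],
        TE_DType[dtype],
        QKVLayout[qkv_layout],
        AttnBiasType[attn_bias_type],
        AttnMaskType[attn_mask_type],
        dropout_p,
        num_heads,
        num_gqa_groups,
        max_seqlen_q,
        max_seqlen_kv,
        head_dim,
    )
    return backend != FusedAttnBackend["No_Backend"]

# Example: a hypothetical GQA config with 24 query heads and 8 KV groups at 2k/2k sequence length.
# print(fused_attn_backend_available(torch.bfloat16, 24, 8, 2048, 2048, 128))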
parent ff760a9d
-Subproject commit 12f35fa2be5994c1106367cac2fba21457b064f4
+Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
@@ -8,6 +8,7 @@ set -e
 git clone https://github.com/NVIDIA/Megatron-LM.git
 cd Megatron-LM
-git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9
+git checkout bcce6f54e075e3c3374ea67adefe54f3f2da2b07
+sed -i -e '1504,1505d' megatron/model/transformer.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py
 python $TE_PATH/tests/pytorch/distributed/print_logs.py
@@ -77,10 +77,11 @@ class TestDistributedSelfAttn:
         is_training = True
         scaling_factor = 1.0
-        _, seqlen, _, _, hidden = data_shape
+        _, seqlen, _, num_head, hidden = data_shape
         if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
-                                              attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
+                                              attn_mask_type, dropout_prob, num_head, num_head,
+                                              seqlen, seqlen, hidden):
             pytest.skip(f"No FusedAttn backwend found")

         def target_func(qkv, bias, mask):
@@ -182,10 +183,11 @@ class TestDistributedCrossAttn:
         is_training = True
         scaling_factor = 1.0
-        _, seqlen, _, hidden = data_shape
+        _, seqlen, num_head, hidden = data_shape
         if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
-                                              attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
+                                              attn_mask_type, dropout_prob, num_head, num_head,
+                                              seqlen, seqlen, hidden):
             pytest.skip(f"No FusedAttn backwend found")

         def target_func(q, kv, mask):
...
@@ -180,12 +180,14 @@ class TestSelfFusedAttn():
     @staticmethod
     def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
-                      head_dim):
+                      num_heads_q, num_heads_kv, head_dim):
         assert isinstance(backend, Backend)
         if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
-                                              attn_mask_type, dropout_probability, s, s, head_dim):
+                                              attn_mask_type, dropout_probability,
+                                              num_heads_q, num_heads_kv,
+                                              s, s, head_dim):
             pytest.skip("Unsupported inputs combination or device compute capability.")

     def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
@@ -197,6 +199,8 @@ class TestSelfFusedAttn():
                            backend=backend,
                            dropout_probability=dropout_probability,
                            dtype=dtype,
+                           num_heads_q=h,
+                           num_heads_kv=h,
                            head_dim=d)
         if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
...
@@ -48,11 +48,8 @@ class TestGroupSharding(unittest.TestCase):
     def _get_model_and_optimizer(self, model, stage):
         if stage == 1:
             optimizer = DygraphShardingOptimizer(
-                hcg=fleet.get_hybrid_communicate_group(),
-                user_defined_strategy=self.strategy,
-                params=model.parameters(),
-                inner_optimizer_class=paddle.optimizer.AdamW,
-                learning_rate=0.01,
+                paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()),
+                fleet.get_hybrid_communicate_group(),
             )
             model = fleet.distributed_model(model)
             optimizer = fleet.distributed_optimizer(optimizer)
...
@@ -634,9 +634,11 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-            head_size=head_size,
+            num_heads=num_heads,
+            num_gqa_groups=num_heads,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
+            head_size=head_size,
             dtype=math_dtype,
             dropout=0.0,
             qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
@@ -762,9 +764,11 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-            head_size=hidden_size // num_heads,
+            num_heads=num_heads,
+            num_gqa_groups=num_heads,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
+            head_size=hidden_size // num_heads,
             dtype=math_dtype,
             dropout=0.0,
             qkv_layout="bs3hd",
@@ -940,9 +944,11 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-            head_size=hidden_size // num_heads,
+            num_heads=num_heads,
+            num_gqa_groups=num_heads,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
+            head_size=hidden_size // num_heads,
             dtype=math_dtype,
             dropout=0.0,
             qkv_layout="bs3hd",
@@ -952,6 +958,8 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
         pytest.skip("cuDNN fused attention is not supported")
     if not is_fused_attention_supported(
             head_size=hidden_size // num_heads,
+            num_heads=num_heads,
+            num_gqa_groups=num_heads,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
...
@@ -688,9 +688,11 @@ class TestFusedAttn:
             else "bshd_bs2hd"
         )
         fused_attention_backend = get_fused_attention_backend(
-            head_size=self.head_size,
+            num_heads=self.num_heads,
+            num_gqa_groups=self.num_heads,
             q_seqlen=self.q_seqlen,
             kv_seqlen=self.kv_seqlen,
+            head_size=self.head_size,
             dtype=self.dtype,
             dropout=self.dropout_prob,
             qkv_layout=qkv_layout,
@@ -774,9 +776,11 @@ class TestFusedAttn:
         test self attention forward + backward
         """
         if not is_fused_attention_supported(
-                head_size=d,
+                num_heads=h,
+                num_gqa_groups=h,
                 q_seqlen=s,
                 kv_seqlen=s,
+                head_size=d,
                 dtype=dtype,
                 dropout=0.0,
                 qkv_layout="bs3hd",
@@ -799,9 +803,11 @@ class TestFusedAttn:
         test cross attention forward + backward
         """
         if not is_fused_attention_supported(
-                head_size=d,
+                num_heads=h,
+                num_gqa_groups=h,
                 q_seqlen=s_q,
                 kv_seqlen=s_kv,
+                head_size=d,
                 dtype=dtype,
                 dropout=0.0,
                 qkv_layout="bshd_bs2hd",
@@ -825,9 +831,11 @@ class TestFusedAttn:
         test flash attention forward + backward
         """
         if not is_fused_attention_supported(
-                head_size=d,
+                num_heads=h,
+                num_gqa_groups=h,
                 q_seqlen=s,
                 kv_seqlen=s,
+                head_size=d,
                 dtype=dtype,
                 dropout=0.0,
                 qkv_layout="bs3hd",
...
@@ -102,9 +102,11 @@ def set_random_seed(seed):
 def get_fused_attention_backend(
-    head_size: int,
+    num_heads: int,
+    num_gqa_groups: int,
     q_seqlen: int,
     kv_seqlen: int,
+    head_size: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
     qkv_layout: str = "bs3hd",
@@ -125,6 +127,8 @@ def get_fused_attention_backend(
         AttnBiasType[bias_type],
         AttnMaskType[mask_type],
         dropout,
+        num_heads,
+        num_gqa_groups,
         q_seqlen,
         kv_seqlen,
         head_size,
@@ -132,9 +136,11 @@ def get_fused_attention_backend(
 def is_fused_attention_supported(
-    head_size: int,
+    num_heads: int,
+    num_gqa_groups: int,
     q_seqlen: int,
     kv_seqlen: int,
+    head_size: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
     qkv_layout: str = "bs3hd",
@@ -143,9 +149,11 @@ def is_fused_attention_supported(
 ) -> bool:
     """Check if cuDNN fused attention is supported for attention config"""
     backend = get_fused_attention_backend(
-        head_size=head_size,
+        num_heads=num_heads,
+        num_gqa_groups=num_gqa_groups,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
+        head_size=head_size,
         dtype=dtype,
         dropout=dropout,
         qkv_layout=qkv_layout,
...
@@ -81,7 +81,6 @@ options=" \
     --merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt \
     --save-interval ${SAVE_INTERVAL} \
     --save ${CHECKPOINT_DIR} \
-    --load ${CHECKPOINT_DIR} \
     --split ${SPLIT} \
     --clip-grad ${CLIP_GRAD} \
     --weight-decay ${WEIGHT_DECAY} \
@@ -90,8 +89,6 @@ options=" \
     --init-method-std ${INIT_METHOD_STD} \
     --log-params-norm \
     --log-num-zeros-in-grad \
-    --no-query-key-layer-scaling \
-    --DDP-impl local \
     --transformer-impl ${TRANSFORMER_IMPL} \
     --tensorboard-dir ${TENSORBOARD_DIR} \
     --fp8-margin 0 \
@@ -108,7 +105,7 @@ if [[ "$WGRAD_FUSION" == "False" ]]; then
 fi
 if [[ "$FP8" != "False" ]]; then
-    options+=" --fp8-${FP8}"
+    options+=" --fp8-format ${FP8}"
 fi
 if [[ "$DTYPE" != "fp32" ]]; then
...
@@ -2,8 +2,10 @@
 #
 # See LICENSE for license information.
+import functools
 from importlib.metadata import version
 import os
+import math
 from typing import Any, Dict, List, Tuple, Union

 from pkg_resources import packaging
@@ -26,6 +28,10 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_bwd,
     fused_attn_fwd,
 )
+from transformer_engine.pytorch.distributed import (
+    _set_cuda_rng_state,
+    CudaRNGStatesTracker,
+)
 import transformer_engine.pytorch.fp8 as fp8
 from transformer_engine.pytorch.module.base import (
     TransformerEngineBaseModule,
@@ -36,231 +42,304 @@ from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
 )
-from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
 import transformer_engine_extensions as tex
+from transformer_engine_extensions import NVTE_Fused_Attn_Backend

-# Only run FP8 tests on H100
+# Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

-_flash_attn_version = packaging.version.Version(version("flash-attn"))
-_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
-
+# Initialize RNG state
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
+# Record initial RNG state from script run.
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()

-def _get_cudnn_version():
-    cudnn_version_encoded = ext.get_cudnn_version()
-    cudnn_major = cudnn_version_encoded // 1000
-    cudnn_minor = (cudnn_version_encoded - cudnn_major * 1000) // 100
-    cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
-    return [cudnn_major, cudnn_minor, cudnn_patch]
-
 def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
+    """Revert back to initial RNG state"""
     torch.set_rng_state(_cpu_rng_state)
     _set_cuda_rng_state(_cuda_rng_state)

-_cudnn_version = _get_cudnn_version()
+@functools.cache
+def _cudnn_version() -> Tuple[int, int, int]:
+    """Runtime cuDNN version (major, minor, patch)"""
+    encoded_version = ext.get_cudnn_version()
+    major, encoded_version = divmod(encoded_version, 1000)
+    minor, patch = divmod(encoded_version, 100)
+    return (major, minor, patch)

 class ModelConfig:
     def __init__(
-        self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
-        dropout_p, attn_mask_type,
+        self,
+        batch_size: int,
+        num_heads: int,
+        num_gqa_groups: int,
+        head_dim: int,
+        max_seqlen_q: int,
+        max_seqlen_kv: int,
+        dropout_p: float,
+        attn_mask_type: str,
+        attn_bias_type: str,
+        num_layers: int = 1,
     ):
-        self.num_layers = num_layers
-        self.hidden_size = hidden_size
-        self.num_attention_heads = num_attention_heads
+        self.batch_size = batch_size
+        self.num_heads = num_heads
+        self.num_gqa_groups = num_gqa_groups
         self.head_dim = head_dim
-        assert (hidden_size == num_attention_heads * head_dim
-                ), """hidden_size must be = num_heads x head_dim."""
-        self.seq_len = seq_len
+        self.hidden_size = num_heads * head_dim
+        self.hidden_size_kv = num_gqa_groups * head_dim
+        self.max_seqlen_q = max_seqlen_q
+        self.max_seqlen_kv = max_seqlen_kv
         self.dropout_p = dropout_p
         self.attn_mask_type = attn_mask_type
+        self.attn_bias_type = attn_bias_type
+        self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
+        self.num_layers = num_layers

-model_configs = {
-    "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
-    "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
-    "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
-    "test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
-    "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
-}
-
-param_types = [torch.float16]
-if torch.cuda.is_bf16_supported():
-    param_types.append(torch.bfloat16)
-batch_sizes = [1, 32]
-
-model_configs_lean = {
-    "test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
-    "test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
-}
-param_types_lean = [torch.bfloat16]
-batch_sizes_lean = [2]
-
 def _is_fused_attention_supported(
     config: ModelConfig,
     dtype: torch.dtype,
     qkv_layout: str = "sbh3d",
-    bias_type: str = "no_bias",
-) -> bool:
+) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
+    """Check if FusedAttention supports a model configuration"""
+    backends = []
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
     backend = tex.get_fused_attn_backend(
         TE_DType[dtype],
         TE_DType[dtype],
         QKVLayout[qkv_layout],
-        AttnBiasType[bias_type],
+        AttnBiasType[config.attn_bias_type],
         AttnMaskType[config.attn_mask_type],
         config.dropout_p,
-        config.seq_len,
-        config.seq_len,
+        config.num_heads,
+        config.num_gqa_groups,
+        config.max_seqlen_q,
+        config.max_seqlen_kv,
         config.head_dim,
     )
-    return backend != FusedAttnBackend["No_Backend"]
+    if backend == FusedAttnBackend["FP8"]:
+        backends.append(backend)
+        return True, backends
+    if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+        backends.append(backend)
+        return True, backends
+    if backend == FusedAttnBackend["F16_max512_seqlen"]:
+        backends.append(backend)
+        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+        backend = tex.get_fused_attn_backend(
+            TE_DType[dtype],
+            TE_DType[dtype],
+            QKVLayout[qkv_layout],
+            AttnBiasType[config.attn_bias_type],
+            AttnMaskType[config.attn_mask_type],
+            config.dropout_p,
+            config.num_heads,
+            config.num_gqa_groups,
+            config.max_seqlen_q,
+            config.max_seqlen_kv,
+            config.head_dim,
+        )
+        if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+            backends.append(backend)
+        return True, backends
+    return False, backends

-def _is_flash_attention_supported(bias_type: str = "no_bias") -> bool:
+@functools.cache
+def _is_flash_attention_2_available() -> bool:
+    """Check if flash-attn 2.0+ is available"""
+    Version = packaging.version.Version
+    return Version(version("flash-attn")) >= Version("2")
+
+@functools.cache
+def _is_flash_attention_2_1() -> bool:
+    """Check if flash-attn 2.0+ is available"""
+    Version = packaging.version.Version
+    return Version(version("flash-attn")) >= Version("2.1")
+
+def _is_flash_attention_supported(config: ModelConfig) -> bool:
+    """Check if FlashAttention supports a model configuration"""
     if get_device_compute_capability() < (8, 0):
         return False
-    if bias_type != "no_bias":
+    if config.attn_bias_type != "no_bias":
         return False
+    if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
+        return False
+    if "causal" in config.attn_mask_type and config.attn_type == "cross":
+        if _is_flash_attention_2_1():
+            # FAv2.1 implements causal mask for cross attention differently
+            # https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
+            return False
+    return True
+
+def _is_unfused_attention_supported(config: ModelConfig) -> bool:
+    """Check if UnfusedDotProductAttention supports a model configuration"""
+    if ("padding" in config.attn_mask_type):
+        return False
+    if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
+        return False
     return True

+model_configs_base = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias   # attn , backend
+    "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
+    "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
+    "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
+    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
+}
+
+param_types = [torch.float16]
+if torch.cuda.is_bf16_supported():
+    param_types.append(torch.bfloat16)
+param_types_lean = [torch.bfloat16]
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes_lean)
-@pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("ckpt_attn", [True, False])
-@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
-def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
-    """Test DotProductAttention module with different backends"""
+@pytest.mark.parametrize("model_configs", [model_configs_base])
+@pytest.mark.parametrize("model", model_configs_base.keys())
+@pytest.mark.parametrize("ckpt_attn", [False])
+@pytest.mark.parametrize("workspace_opt", [True, False])
+@pytest.mark.parametrize("qkv_layout", [None])
+def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
+    """Test DotProductAttention module"""

     # Get configs
+    config = model_configs[model]
     tols = dict(atol=5e-3, rtol=5e-3)
     if dtype == torch.bfloat16:
         tols = dict(atol=2.5e-2, rtol=2.5e-2)
-    config = model_configs[model]
+    if qkv_layout is None:
+        if config.attn_type == "self":
+            qkv_layout = "sb3hd"
+        else:
+            qkv_layout = "sbhd_sb2hd"
+    if "3" in qkv_layout and config.attn_type == "cross":
+        pytest.skip(
+            "No need to test this layout for cross attention"
+        )

     # Skip if only unfused backend is supported
-    fused_attn_supported = _is_fused_attention_supported(
-        config,
-        dtype,
-        bias_type=bias_type,
+    unfused_attn_supported = _is_unfused_attention_supported(config)
+    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
+        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
+    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+        config, dtype, qkv_layout=qkv_layout,
     )
-    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
-    if not (fused_attn_supported or flash_attn_supported):
-        pytest.skip(
-            "Neither FusedAttention nor FlashAttention support this model config"
-        )
+    flash_attn_supported = _is_flash_attention_supported(config)
+    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+        pytest.skip("Less than two backends to compare.")

     # UnfusedDotProductAttention backend
-    unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
-        dtype,
-        bs,
-        config,
-        "UnfusedDotProductAttention",
-        ckpt_attn,
-        bias_type,
-    )
+    if unfused_attn_supported:
+        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
+            dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt,
+        )

     # FusedAttention backend
     if fused_attn_supported:
-        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-            dtype,
-            bs,
-            config,
-            "FusedAttention",
-            ckpt_attn,
-            bias_type,
-        )
-        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
+        if len(fused_attn_backend) == 1:
+            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+            )
+        if len(fused_attn_backend) == 2:
+            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
+            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+            )
+            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+            )

     # FlashAttention backend
     if flash_attn_supported:
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype,
-            bs,
-            config,
-            "FlashAttention",
-            ckpt_attn,
-            bias_type,
+            dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt,
         )
-        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
-        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)

-def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
-    reset_rng_states()
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "0"
-    if backend == "FlashAttention":
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-    if backend == "FusedAttention":
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
-
-    inp = torch.randn(
-        config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
-        dtype=dtype).cuda()
-    inp.requires_grad=True
-    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
-    seqlens.fill_(config.seq_len)
-    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
-    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
-    op_grad = torch.randn(
-        config.seq_len, bs, config.num_attention_heads * config.head_dim,
-        dtype = dtype).cuda()
-    if bias_type != "no_bias":
-        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
-            dtype=dtype).cuda()
-    else:
-        bias = None
-
-    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
-    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
-
-    def get_dummy_cuda_rng_tracker():
-        """Get cuda rng tracker."""
-        return _DUMMY_CUDA_RNG_STATE_TRACKER
-
-    block = (
-        DotProductAttention(
-                config.num_attention_heads,
-                config.head_dim,
-                attention_dropout=config.dropout_p,
-                sequence_parallel=False,
-                tp_size=1,
-                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
-                tp_group=None,
-                layer_number=1,
-                attention_type="self"
-        ).to(dtype=dtype).cuda()
-    )
-
-    q = inp[:, :,0,:,:]
-    k = inp[:, :,1,:,:]
-    v = inp[:, :,2,:,:]
-    op = block(q, k, v,
-               qkv_format='sbhd',
-               cu_seqlens_q = cu_seqlens,
-               cu_seqlens_kv = cu_seqlens,
-               attn_mask_type=config.attn_mask_type,
-               checkpoint_core_attention=ckpt_attn,
-               core_attention_bias_type=bias_type,
-               core_attention_bias=bias)
-    op.backward(op_grad)
-
-    return op, inp.grad
+    if unfused_attn_supported and fused_attn_supported:
+        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        for i,_ in enumerate(unfused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
+    if unfused_attn_supported and flash_attn_supported:
+        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
+        for i,_ in enumerate(flash_attn_bwd):
+            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
+    if fused_attn_supported and flash_attn_supported:
+        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
+        for i,_ in enumerate(flash_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
+    if fused_attn_supported and len(fused_attn_backend) == 2:
+        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+        for i,_ in enumerate(fused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_base])
+@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
+def test_dpa_checkpoint(dtype, model_configs, model):
+    """Test DotProductAttention module with checkpointing"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, None)
+
+model_configs_mask = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+    "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+    "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+    "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+    "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+    "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
+    "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+    "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
+    "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+}
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_mask])
+@pytest.mark.parametrize("model", model_configs_mask.keys())
+def test_dpa_mask(dtype, model_configs, model):
+    """Test DotProductAttention module with different mask types"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None)
+
+model_configs_bias = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
+    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
+    "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
+    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
+    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
+    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
+    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
+    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
+    "bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
+    "bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
+    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
+    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
+    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
+    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
+}
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_bias])
+@pytest.mark.parametrize("model", model_configs_bias.keys())
+def test_dpa_bias(dtype, model_configs, model):
+    """Test DotProductAttention module with different bias types"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None)

 qkv_layouts = [
     'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
@@ -269,54 +348,39 @@ qkv_layouts = [
     #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
     ]

-@pytest.mark.skipif(
-    _cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
+model_configs_layout = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+    "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+    "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+    "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
+    "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+    "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+    "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
+}
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
-@pytest.mark.parametrize("bs", batch_sizes_lean)
-@pytest.mark.parametrize("model", model_configs_lean.keys())
-@pytest.mark.parametrize("workspace_opt", [True, False])
+@pytest.mark.parametrize("model_configs", [model_configs_layout])
+@pytest.mark.parametrize("model", model_configs_layout.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layouts)
-def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
+def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with different QKV layouts"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout)

-    # Get configs
-    config = model_configs_lean[model]
-    tols = dict(atol=5e-3, rtol=5e-3)
-    if dtype == torch.bfloat16:
-        tols = dict(atol=2.5e-2, rtol=2.5e-2)
-
-    # Skip if only unfused backend is supported
-    fused_attn_supported = _is_fused_attention_supported(config, dtype)
-    flash_attn_supported = _is_flash_attention_supported()
-    if not (fused_attn_supported or flash_attn_supported):
-        pytest.skip(
-            "Neither FusedAttention nor FlashAttention support this model config"
-        )
-
-    # UnfusedDotProductAttention backend
-    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
-        dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
-
-    # FusedAttention backend
-    if fused_attn_supported:
-        fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
-            dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
-        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-        for i in range(len(unfused_attn_bwd)):
-            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
-
-    # FlashAttention backend
-    if flash_attn_supported:
-        flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
-            dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
-        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
-        for i in range(len(unfused_attn_bwd)):
-            torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], **tols)
-
-def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed(1234)
+def _run_dot_product_attention(
+    dtype: torch.dtype,
+    config: ModelConfig,
+    backend: str,
+    ckpt_attn: bool,
+    qkv_layout: str,
+    workspace_opt: bool,
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """Run DotProductAttention module with one forward pass and one backward pass"""
+
+    # Set RNG and environment varables
+    reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if backend == "FlashAttention":
@@ -325,271 +389,284 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"

-    dim_to_num = {'b': bs,
-                  's': config.seq_len,
-                  'h': config.num_attention_heads,
-                  'd': config.head_dim,
-                  't': bs * config.seq_len,
-                  '3': 3,
-                  '2': 2}
+    # Create seqlens
+    qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
+    if "padding" in config.attn_mask_type or qkv_format == 'thd':
+        if config.attn_type == 'self':
+            seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
+                    dtype=torch.int32, device="cuda")
+            seqlens_kv = seqlens_q
+        if config.attn_type == 'cross':
+            seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
+                    dtype=torch.int32, device="cuda")
+            seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
+                    dtype=torch.int32, device="cuda")
+    else:
+        seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
+                dtype=torch.int32, device="cuda")
+        seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
+                dtype=torch.int32, device="cuda")
+    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+
+    # Create attention mask if padding
+    attention_mask = None
+    if "padding" in config.attn_mask_type:
+        if config.attn_type == 'self':
+            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
+            for i in range(config.batch_size):
+                attention_mask_q = torch.cat([attention_mask_q,
+                    torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
+                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+            attention_mask = attention_mask_q.to(device="cuda")
+        if config.attn_type == 'cross':
+            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
+            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
+            for i in range(config.batch_size):
+                attention_mask_q = torch.cat([attention_mask_q,
+                    torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
+                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+                attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
+                    [True]*seqlens_kv[i] + [False]*(config.max_seqlen_kv-seqlens_kv[i]))
+                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+            attention_mask = (
+                attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
+
+    # Create input tensors
+    dim_to_num = {
+        'b'  : config.batch_size,
+        'sq' : config.max_seqlen_q,
+        'skv': config.max_seqlen_kv,
+        'h'  : config.num_heads,
+        'hg' : config.num_gqa_groups,
+        'd'  : config.head_dim,
+        't'  : cu_seqlens_q[-1],
+        'tg' : cu_seqlens_kv[-1],
+        '3'  : 3,
+        '2'  : 2,
+    }
     inp = []
     for i,layout in enumerate(qkv_layout.split('_')):
-        tensor_shape = [dim_to_num[j] for j in layout]
-        tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda()
+        layout = '_'.join(layout)
+        if i == 0:
+            layout = layout.replace('s', 'sq')
+        else:
+            layout = layout.replace('s', 'skv')
+            layout = layout.replace('h', 'hg')
+            layout = layout.replace('t', 'tg')
+        tensor_shape = [dim_to_num[j] for j in layout.split('_')]
+        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
         tensor_count = 1
         split_dim = 0
-        for dim,l in enumerate(layout):
+        for dim, l in enumerate(layout.split('_')):
             if l.isdigit():
                 tensor_count = int(l)
                 split_dim = dim
                 break
-        tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor]
+        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
         for j in range(tensor_count):
             if split_dim != 0:
                 inp.append(tensors[j].squeeze(split_dim))
             else:
                 inp.append(tensors[j])
     for i in range(3):
-        inp[i].requires_grad=True
+        inp[i].requires_grad = True

-    seqlens = torch.empty(bs, dtype = torch.int32).cuda()
-    seqlens.fill_(config.seq_len)
-    cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32)
-    cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
-    qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
-    qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
-    op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd]
-    op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
-    op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda()
+    # Create output gradient
+    qkv_format_kv = '_'.join(qkv_format)
+    qkv_format_kv = qkv_format_kv.replace('s', 'sq')
+    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
+    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
+    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
+
+    # Create bias
+    if config.attn_bias_type in ['no_bias', 'alibi']:
+        bias = None
+    if config.attn_bias_type == 'post_scale_bias':
+        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
+                dtype=dtype, device="cuda")
+
+    # Create RNG
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER

+    # Set up model
     block = (
         DotProductAttention(
-                config.num_attention_heads,
+                config.num_heads,
                 config.head_dim,
-                attention_dropout = config.dropout_p,
-                attn_mask_type = config.attn_mask_type,
-                sequence_parallel = False,
-                tp_size = 1,
-                get_rng_state_tracker = None,
-                tp_group = None,
-                layer_number = 1,
-                attention_type = "self"
-        ).to(dtype = dtype).cuda()
-    )
-
-    if qkv_format != 'thd':
-        op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
-    else:
-        cu_seqlens_q = torch.arange(
-            0,
-            (bs + 1) * config.seq_len,
-            step=config.seq_len,
-            dtype=torch.int32,
-            device=inp[0].device)
-        cu_seqlens_kv = torch.arange(
-            0,
-            (bs + 1) * config.seq_len,
-            step=config.seq_len,
-            dtype=torch.int32,
-            device=inp[1].device)
-        op = block(inp[0], inp[1], inp[2],
-                   qkv_format=qkv_format,
-                   cu_seqlens_q = cu_seqlens_q,
-                   cu_seqlens_kv = cu_seqlens_kv)
-    op.backward(op_grad)
-
-    return op, (inp[0].grad, inp[1].grad, inp[2].grad)
+                num_gqa_groups=config.num_gqa_groups,
+                attention_dropout=config.dropout_p,
+                qkv_format=qkv_format,
+                attn_mask_type=config.attn_mask_type,
+                sequence_parallel=False,
+                tp_size=1,
+                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+                tp_group=None,
+                layer_number=1,
+                attention_type=config.attn_type,
+        ).to(dtype=dtype, device="cuda")
+    )
+
+    # Run a forward and backward pass
+    out = block(inp[0], inp[1], inp[2],
+                attention_mask=attention_mask,
+                qkv_format=qkv_format,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                attn_mask_type=config.attn_mask_type,
+                checkpoint_core_attention=ckpt_attn,
+                core_attention_bias_type=config.attn_bias_type,
+                core_attention_bias=bias,
+                fast_zero_fill=True)
+    out.backward(out_grad)
+
+    return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+
+model_configs_te_layer = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
+    "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+    "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
+    "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
+    "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+}
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs_lean.keys())
-@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
-@pytest.mark.parametrize("fused_qkv_params", [True, False])
-@pytest.mark.parametrize("RoPE", [True, False])
-def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
-    """Test TransformerLayer module when its DotProductAttention is enabled with
-    FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
+@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
+@pytest.mark.parametrize("model", model_configs_te_layer.keys())
+@pytest.mark.parametrize("ckpt_attn", [False])
+@pytest.mark.parametrize("qkv_format", ["sbhd"])
+@pytest.mark.parametrize("fused_qkv_params", [False])
+@pytest.mark.parametrize("RoPE", [False])
+def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
+    """Test TransformerLayer module"""

     # Get configs
-    config = model_configs_lean[model]
+    config = model_configs[model]
     tols = dict(atol=5e-1, rtol=5e-2)
+    workspace_opt = True

-    # TODO @cyanguwa: Handle test cases more cleanly
-    if config.hidden_size > 1024:
-        pytest.skip(
-            "Tolerances for test_transformer_layer are intended for small test cases"
-        )
-
     # Skip if only unfused backend is supported
-    fused_attn_supported = _is_fused_attention_supported(
+    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
+        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
+    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
         config,
         dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
-        bias_type=bias_type,
     )
-    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
-    if not (fused_attn_supported or flash_attn_supported):
-        pytest.skip(
-            "Neither FusedAttention nor FlashAttention support this model config"
-        )
+    flash_attn_supported = _is_flash_attention_supported(config)
+    unfused_attn_supported = _is_unfused_attention_supported(config)
+    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+        pytest.skip("Less than two backends to compare.")

     # UnfusedDotProductAttention backend
-    unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
-        dtype,
-        bs,
-        config,
-        "UnfusedDotProductAttention",
-        bias_type,
-        fused_qkv_params,
-        RoPE,
-    )
+    if unfused_attn_supported:
+        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
+            dtype,
+            config,
+            "UnfusedDotProductAttention",
+            ckpt_attn,
+            qkv_format,
+            workspace_opt,
+            fused_qkv_params,
+            RoPE,
+        )

     # FusedAttention backend
     if fused_attn_supported:
         fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
             dtype,
-            bs,
             config,
             "FusedAttention",
-            bias_type,
+            ckpt_attn,
+            qkv_format,
+            workspace_opt,
             fused_qkv_params,
             RoPE,
         )
-        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)

     # FlashAttention backend
     if flash_attn_supported:
         flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
             dtype,
-            bs,
             config,
             "FlashAttention",
-            bias_type,
+            ckpt_attn,
+            qkv_format,
+            workspace_opt,
             fused_qkv_params,
             RoPE,
         )
+
+    if unfused_attn_supported and fused_attn_supported:
+        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
+    if unfused_attn_supported and flash_attn_supported:
         torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
         torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
+    if fused_attn_supported and flash_attn_supported:
+        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
+        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)

-def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):
-    reset_rng_states()
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "0"
-    if backend == "FlashAttention":
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-    if backend == "FusedAttention":
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-
-    inp = torch.randn(
-        config.seq_len, bs, config.num_attention_heads * config.head_dim,
-        dtype=dtype).cuda()
-    inp.requires_grad=True
-    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
-    seqlens.fill_(config.seq_len)
-    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
-    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
-
-    sigma = 0.02
-    init_method = init_method_normal(sigma)
-    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    layer_number = 1
-    drop_path_rate = 0.0
-    drop_path_rates = [
-        rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
-    if bias_type != "no_bias":
-        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
-            dtype=dtype).cuda()
-    else:
-        bias = None
-
-    rotary_pos_emb = None
-    if RoPE:
-        PE = RotaryPositionEmbedding(dim=config.head_dim)
-        rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype)
-
-    block = (
-        TransformerLayer(
-            config.hidden_size,
-            4 * config.hidden_size,
-            config.num_attention_heads,
-            layernorm_epsilon=1e-5,
-            hidden_dropout=0.0,
-            attention_dropout=config.dropout_p,
-            init_method=init_method,
-            output_layer_init_method=output_layer_init_method,
-            layer_number=layer_number,
-            kv_channels=config.head_dim,
-            tp_group=None,
-            tp_size=1,
-            params_dtype=dtype,
-            get_rng_state_tracker=None,
-            fuse_wgrad_accumulation=False,
-            seq_length=config.seq_len,
-            micro_batch_size=bs,
-            sequence_parallel=False,
-            apply_residual_connection_post_layernorm=False,
-            output_layernorm=False,
-            layer_type="encoder",
-            drop_path_rate=drop_path_rates[layer_number - 1],
-            set_parallel_mode=True,
-            fuse_qkv_params=fused_qkv_params,
-            zero_centered_gamma=False,
-            qkv_weight_interleaved=False,
-            ub_tp_comm_overlap=False,
-            bias=True,
-        )
-        .to(dtype=dtype)
-        .cuda()
-    )
-
-    num_iters = 5
-    for i in range(num_iters):
-        op = block(inp, self_attn_mask_type=config.attn_mask_type,
-                   rotary_pos_emb=rotary_pos_emb,
-                   core_attention_bias_type=bias_type,
-                   core_attention_bias=bias)
-        loss = op.sum()
-        loss.backward()
-
-    return op, inp.grad
-
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
-@pytest.mark.parametrize("bs", batch_sizes_lean)
-@pytest.mark.parametrize("model", model_configs_lean.keys())
-def test_transformer_layer_gqa(dtype, bs, model):
+@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
+@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
+def test_te_layer_misc(dtype, model_configs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with """Test TransformerLayer module with miscellanous settings"""
FlashAttention or UnfusedDotProductAttention backend""" ckpt_attn = True
qkv_format = "bshd"
config = model_configs_lean[model] fused_qkv_params = True
RoPE = True
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA"""
def find_factors(x): def find_factors(x):
f = [] f = []
for i in range(1, x + 1): for i in range(2, x + 1):
if x % i == 0: if x % i == 0:
f.append(i) f.append(i)
return f return f
# Skip if only unfused backend is supported ckpt_attn = True
if not (_flash_attn_2_available and _is_flash_attention_supported()): qkv_format = "bshd"
pytest.skip("FlashAttention does not support this model config") fused_qkv_params = True
RoPE = True
num_querys_per_gqa_group = find_factors(config.num_attention_heads) config = model_configs[model]
num_querys_per_gqa_group = find_factors(config.num_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group: for num_q_per_gqa_group in num_querys_per_gqa_group:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa( config.num_gqa_groups=config.num_heads // num_q_per_gqa_group
dtype, bs, config, "FlashAttention", num_q_per_gqa_group) test_transformer_layer(dtype, model_configs, model,
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa( ckpt_attn, qkv_format, fused_qkv_params, RoPE)
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
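For reference, the MQA/GQA test above sweeps every divisor of the head count: each divisor is used as the number of query heads per GQA group, and the group count follows by integer division. A small illustration with assumed numbers (not part of the test):

```python
# Hypothetical example: a 16-head layer swept over queries-per-group sizes.
num_heads = 16
factors = [i for i in range(2, num_heads + 1) if num_heads % i == 0]   # [2, 4, 8, 16]
for num_q_per_gqa_group in factors:
    num_gqa_groups = num_heads // num_q_per_gqa_group
    # 8, 4 and 2 groups exercise GQA; a single group is the MQA limit.
    print(f"{num_q_per_gqa_group} queries/group -> {num_gqa_groups} GQA group(s)")
```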
def _run_transformer_layer(
atol, rtol = 5e-1, 5e-2 dtype: torch.dtype,
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) config: ModelConfig,
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) backend: str,
ckpt_attn: bool,
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group): qkv_layout: str,
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states() reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -598,17 +675,27 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.randn( # Create input tensor
config.seq_len, bs, config.num_attention_heads * config.head_dim, inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
dtype=dtype).cuda() dtype=dtype, device="cuda", requires_grad = True)
inp.requires_grad=True
seqlens = torch.empty(bs, dtype=torch.int32).cuda() # Create seqlens
seqlens.fill_(config.seq_len) if "padding" in config.attn_mask_type:
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) dtype=torch.int32, device="cuda")
op_grad = torch.randn( else:
config.seq_len, bs, config.num_attention_heads * config.head_dim, seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=dtype).cuda() dtype=torch.int32, device="cuda")
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
.to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = attention_mask_q.to(device="cuda")
sigma = 0.02 sigma = 0.02
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
@@ -619,12 +706,43 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
drop_path_rates = [ drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias
if config.attn_bias_type == 'no_bias':
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
elif config.attn_bias_type == 'alibi':
if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
config.attn_bias_type = 'post_scale_bias'
n = 2 ** math.floor(math.log2(config.num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
b = torch.triu(a,diagonal=1)
c = b.cumsum(dim=-1)
d = c - torch.transpose(c, 0, 1)
bias = d.expand(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv).contiguous()
for i in range(config.num_heads):
bias[0,i,:,:] = m[i] * bias[0,i,:,:]
bias = bias.to(dtype=dtype, device="cuda")
else:
bias = None
# Create RoPE
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")
# Set up model
block = ( block = (
TransformerLayer( TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
num_gqa_groups=config.num_attention_heads / num_querys_per_gqa_group, num_gqa_groups=config.num_gqa_groups,
layernorm_epsilon=1e-5, layernorm_epsilon=1e-5,
hidden_dropout=0.0, hidden_dropout=0.0,
attention_dropout=config.dropout_p, attention_dropout=config.dropout_p,
@@ -632,79 +750,85 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
layer_number=layer_number, layer_number=layer_number,
kv_channels=config.head_dim, kv_channels=config.head_dim,
self_attn_mask_type=config.attn_mask_type,
tp_group=None, tp_group=None,
tp_size= 1, tp_size=1,
params_dtype=dtype, params_dtype=dtype,
get_rng_state_tracker=None, get_rng_state_tracker=None,
fuse_wgrad_accumulation=False, fuse_wgrad_accumulation=False,
seq_length=config.seq_len, seq_length=config.max_seqlen_q,
micro_batch_size=bs, micro_batch_size=config.batch_size,
sequence_parallel=False, sequence_parallel=False,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
layer_type="encoder", layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1], drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True, set_parallel_mode=True,
fuse_qkv_params=True, fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False, zero_centered_gamma=False,
qkv_weight_interleaved=False, qkv_weight_interleaved=False,
ub_tp_comm_overlap=False, ub_tp_comm_overlap=False,
bias=True, bias=True,
) )
.to(dtype=dtype) .to(dtype=dtype, device="cuda")
.cuda()
) )
op = block(inp, self_attn_mask_type=config.attn_mask_type) # Run a forward and backward pass
op.backward(op_grad) out = block(inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias)
loss = out.sum()
loss.backward()
return out, inp.grad
return op, inp.grad
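The ALIBI branch above builds per-head slopes m_h = 2^(-8h/n) and a relative-position matrix whose (i, j) entry is j - i, then scales the matrix by each head's slope. A standalone sketch of that construction follows; it is an illustration only (the shapes and the power-of-two head-count assumption are mine), and it clones the expanded tensor so the per-head scaling does not write into a shared buffer:

```python
import math
import torch

def alibi_post_scale_bias(num_heads: int, seqlen_q: int, seqlen_kv: int) -> torch.Tensor:
    """Rough re-implementation of the test's ALIBI-style post_scale_bias."""
    n = 2 ** math.floor(math.log2(num_heads))
    assert n == num_heads, "this sketch assumes a power-of-two head count"
    # Geometric slopes: m_h = (2^(-8/n))^h for h = 1..n.
    slopes = torch.pow(2.0 ** (-8.0 / n), torch.arange(1, 1 + n, dtype=torch.float32))
    # Relative positions: entry (i, j) equals j - i.
    rel = torch.arange(seqlen_kv).view(1, -1) - torch.arange(seqlen_q).view(-1, 1)
    bias = rel.to(torch.float32).expand(1, num_heads, seqlen_q, seqlen_kv).clone()
    for h in range(num_heads):
        bias[0, h] *= slopes[h]
    return bias

# e.g. alibi_post_scale_bias(16, 128, 128) has shape [1, 16, 128, 128]
```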
model_configs_fp8 = { model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"), # test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
} }
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16] param_types_fp8 = [torch.float16]
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys()) @pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model): def test_dpa_fp8(dtype, model):
"""Test FP8 dot-product attention with different backends """Test FP8 dot product attention
FusedAttention uses fused_attn_fwd/bwd_qkvpacked from FusedAttention uses fused_attn_fwd/bwd_qkvpacked from cpp_extensions,
cpp_extensions. UnfusedDotProductAttention uses plain PyTorch and UnfusedDotProductAttention uses plain PyTorch operations in FP16
operations. and converts inputs/outputs from/to FP8.
""" """
config = model_configs_fp8[model] config = model_configs_fp8[model]
# Skip if not supported # Skip if not supported
if not _is_fused_attention_supported(config, dtype): fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype)
if not fused_attn_supported:
pytest.skip("FusedAttention does not support this model config") pytest.skip("FusedAttention does not support this model config")
# Run dot-product attention with different backends # Run dot-product attention with different backends
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8( fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype, dtype, config, "FusedAttention")
bs,
config,
"FusedAttention"
)
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref( unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, dtype, config, "UnfusedDotProductAttention")
bs,
config,
"UnfusedDotProductAttention",
)
# Check that results match
tols = dict(atol=2.5e-2, rtol=2.5e-2) tols = dict(atol=2.5e-2, rtol=2.5e-2)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols) torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
def _run_dpa_fp8(dtype, bs, config, backend): def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
reset_rng_states() reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -715,17 +839,16 @@ def _run_dpa_fp8(dtype, bs, config, backend):
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.01 * torch.randn( inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim, config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype).cuda() dtype=dtype, device="cuda", requires_grad=True)
inp.requires_grad=True seqlens = torch.full([config.batch_size], config.max_seqlen_q,
seqlens = torch.empty(bs, dtype=torch.int32).cuda() dtype=torch.int32, device="cuda")
seqlens.fill_(config.seq_len) cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = 0.01 * torch.randn( out_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim, config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype).cuda() dtype=dtype, device="cuda")
torch.save(op_grad, 'op_grad.pt') torch.save(out_grad, 'out_grad.pt')
fp8_recipe = recipe.DelayedScaling( fp8_recipe = recipe.DelayedScaling(
margin=0, margin=0,
@@ -735,17 +858,21 @@ def _run_dpa_fp8(dtype, bs, config, backend):
amax_compute_algo="most_recent", amax_compute_algo="most_recent",
) )
dpa = DPA_FP8(config).to(dtype=torch.float16).cuda() dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
op = dpa(inp, cu_seqlens, config.seq_len) out = dpa(inp, cu_seqlens, config.max_seqlen_q)
op.backward(op_grad) out.backward(out_grad)
context = torch.load("ctx.pt") context = torch.load("ctx.pt")
dqkv = torch.load('dqkv.pt') dqkv = torch.load('dqkv.pt')
return (context.view(bs, config.seq_len, -1).transpose(0,1), return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1),
dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous()) dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, bs, config, backend): def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
are converted from/to FP8"""
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -754,13 +881,20 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.load('qkv.pt').cuda() inp = torch.load('qkv.pt').to(device="cuda")
inp.requires_grad=True inp.requires_grad = True
seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
seqlens.fill_(config.seq_len) cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1) out_grad = torch.load('out_grad.pt').to(device="cuda").view(
config.batch_size, config.max_seqlen_q, -1).transpose(0,1)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -771,7 +905,7 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
block = ( block = (
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_heads,
config.head_dim, config.head_dim,
attention_dropout=config.dropout_p, attention_dropout=config.dropout_p,
sequence_parallel=False, sequence_parallel=False,
@@ -780,16 +914,17 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
tp_group=None, tp_group=None,
layer_number=1, layer_number=1,
attention_type="self" attention_type="self"
).to(dtype=dtype).cuda() ).to(dtype=dtype, device="cuda")
) )
q = inp[:, :,0,:,:] q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:] k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:] v = inp[:, :,2,:,:]
op = block(q, k, v, attn_mask_type=config.attn_mask_type) out = block(q, k, v, attn_mask_type=config.attn_mask_type)
op.backward(op_grad) out.backward(out_grad)
return out, inp.grad
return op, inp.grad
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB _CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False _2X_ACC_FPROP = False
@@ -812,7 +947,7 @@ class _dpa_fp8(torch.autograd.Function):
qkv_weight: torch.Tensor, qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor, qkv_bias: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
num_attention_heads: int, num_heads: int,
p_dropout: float, p_dropout: float,
max_s: int, max_s: int,
fast_zero_fill: bool, fast_zero_fill: bool,
@@ -823,7 +958,7 @@ class _dpa_fp8(torch.autograd.Function):
assert inp.dim() == 2 assert inp.dim() == 2
in_features = qkv_weight.shape[-1] in_features = qkv_weight.shape[-1]
h = num_attention_heads h = num_heads
d = in_features // h d = in_features // h
b = cu_seqlens.numel() - 1 b = cu_seqlens.numel() - 1
is_nl = False is_nl = False
@@ -921,7 +1056,7 @@ class _dpa_fp8(torch.autograd.Function):
ctx.fast_zero_fill = fast_zero_fill ctx.fast_zero_fill = fast_zero_fill
ctx.is_nl = is_nl ctx.is_nl = is_nl
ctx.hidden_size = in_features ctx.hidden_size = in_features
ctx.num_attention_heads = num_attention_heads ctx.num_heads = num_heads
context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"], context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16) META_O, fp8_dtype_forward, tex.DType.kFloat16)
@@ -1050,7 +1185,7 @@ class DPA_FP8(TransformerEngineBaseModule):
params_dtype: torch.dtype = torch.float32): params_dtype: torch.dtype = torch.float32):
super().__init__() super().__init__()
self.p_dropout = config.dropout_p self.p_dropout = config.dropout_p
self.h = config.num_attention_heads self.h = config.num_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_dim = config.head_dim self.head_dim = config.head_dim
self.fast_zero_fill = True self.fast_zero_fill = True
......
@@ -508,6 +508,7 @@ def _test_e2e_checkpointing_get_model(config, dtype):
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return ( return (
TransformerLayer( TransformerLayer(
config.hidden_size, config.hidden_size,
@@ -524,7 +525,6 @@ def _test_e2e_checkpointing_get_model(config, dtype):
params_dtype=dtype, params_dtype=dtype,
) )
.cuda() .cuda()
.eval()
) )
@@ -559,9 +559,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
if p.requires_grad: if p.requires_grad:
param_grads.append(p.grad.clone()) param_grads.append(p.grad.clone())
global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block del block
block = _test_e2e_checkpointing_get_model(config, dtype) block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path)) block.load_state_dict(torch.load(path))
reset_rng_states()
for p in block.parameters(): for p in block.parameters():
if p.requires_grad: if p.requires_grad:
@@ -815,21 +820,19 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_attention_heads,
config.embed, config.embed,
attention_dropout=0.1, # dropout attention_dropout=0.0, # disable dropout, FU uses rng differently
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
.eval()
) )
torch_dpa = ( torch_dpa = (
TorchDotProductAttention( TorchDotProductAttention(
config.embed, config.embed,
0.1, # dropout 0.0, # dropout
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
.eval()
) )
te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config) te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
......
@@ -11,6 +11,7 @@
#include "fused_attn_f16_arbitrary_seqlen.h" #include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h" #include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#include "../util/system.h"
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
@@ -18,7 +19,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD: case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD: case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_3HD; return NVTE_QKV_Layout_Group::NVTE_3HD;
case NVTE_QKV_Layout::NVTE_SBH3D: case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_BSH3D: case NVTE_QKV_Layout::NVTE_BSH3D:
@@ -27,7 +27,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD: case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_2HD; return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
@@ -36,7 +35,6 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD: case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
default: default:
NVTE_ERROR("qkv_layout not supported!"); NVTE_ERROR("qkv_layout not supported!");
@@ -63,9 +61,6 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_THD_T2HD: case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_THD_TH2D: case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD: case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Format::NVTE_THD; return NVTE_QKV_Format::NVTE_THD;
default: default:
NVTE_ERROR("qkv_layout not supported!"); NVTE_ERROR("qkv_layout not supported!");
@@ -79,8 +74,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q, float dropout,
size_t max_seqlen_kv, size_t head_dim) { size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim) {
using namespace transformer_engine; using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
@@ -91,56 +88,66 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
&& (sm_arch_ >= 90) && (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q == max_seqlen_kv)
&& (num_attn_heads == num_gqa_groups)
&& (max_seqlen_q <= 512) && (max_seqlen_q <= 512)
&& (head_dim == 64) && (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) {
|| (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD))) { if (cudnn_runtime_version >= 8900) {
#if (CUDNN_VERSION >= 8900) backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else {
#else backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." " Please upgrade your cuDNN version if possible." << std::endl;
" Please upgrade your cuDNN version if possible." << std::endl; }
#endif
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false; bool flag_m512 = false;
bool flag_arb = false; bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90) if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (max_seqlen_q <= 512)
&& (max_seqlen_kv <= 512)
&& (head_dim == 64) && (head_dim == 64)
&& (num_attn_heads == num_gqa_groups)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
&& max_seqlen_q == max_seqlen_kv)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) && ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) { || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true; flag_m512 = true;
} }
if ( if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
#if (CUDNN_VERSION >= 8903) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
(sm_arch_ >= 80) && (max_seqlen_q % 64 == 0)
#else && (max_seqlen_kv % 64 == 0)
(sm_arch_ == 80 || sm_arch_ == 90) && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
#endif || (cudnn_runtime_version >= 8907))
&& (max_seqlen_q == max_seqlen_kv) && ((head_dim <= 128) && (head_dim % 8 == 0))
&& ((head_dim == 64) || (head_dim == 128)) && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906 && sm_arch_ == 90)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK)
|| bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)))
&& ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| ((cudnn_runtime_version >= 8906) && || ((cudnn_runtime_version >= 8906)
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))) || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) && (!(cudnn_runtime_version >= 8906
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) { && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
&& bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true; flag_arb = true;
} }
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
@@ -148,34 +155,32 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
} }
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) { if (flag_arb == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
} else if ((flag_arb == false) && (flag_m512 == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
}
int env_backend = static_cast<int>(backend);
env_backend = transformer_engine::getenv<int>("NVTE_FUSED_ATTN_BACKEND", env_backend);
if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen))
&& flag_m512)
|| ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))
&& flag_arb)) {
backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
} }
} }
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND"); if (cudnn_runtime_version < 8901
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512) && backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
#if (CUDNN_VERSION < 8901)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
" Please upgrade your cuDNN version if possible." << std::endl; " Please upgrade your cuDNN version if possible." << std::endl;
} }
#endif if (cudnn_runtime_version < 8900
#if (CUDNN_VERSION < 8900) && backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl; " Please upgrade your cuDNN version if possible." << std::endl;
} }
#endif
} else { } else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
} }
@@ -208,10 +213,17 @@ void nvte_fused_attn_fwd_qkvpacked(
Tensor *output_O = reinterpret_cast<Tensor*>(O); Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace); Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
auto ndim = input_QKV->data.shape.size(); auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1; size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[ndim - 2]; size_t h = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
h = input_QKV->data.shape[ndim - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
h = input_QKV->data.shape[ndim - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1]; size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
@@ -221,12 +233,12 @@ void nvte_fused_attn_fwd_qkvpacked(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
QKV_type, QKV_type, QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d); dropout, h, h, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked( fused_attn_max_512_fwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -239,7 +251,7 @@ void nvte_fused_attn_fwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
...@@ -253,7 +265,7 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -253,7 +265,7 @@ void nvte_fused_attn_fwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked( fused_attn_fp8_fwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O, input_QKV, input_output_S, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -297,10 +309,17 @@ void nvte_fused_attn_bwd_qkvpacked(
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias); Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace); Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
auto ndim = input_QKV->data.shape.size(); auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1; size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[ndim - 2]; size_t h = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
h = input_QKV->data.shape[ndim - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
h = input_QKV->data.shape[ndim - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1]; size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
@@ -310,13 +329,13 @@ void nvte_fused_attn_bwd_qkvpacked(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
QKV_type, QKV_type, QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d); dropout, h, h, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked( fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO, input_QKV, input_dO,
output_S, output_S,
@@ -329,11 +348,17 @@ void nvte_fused_attn_bwd_qkvpacked(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]); Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO, input_QKV, input_O, input_dO, input_Bias,
output_S, output_S,
output_dQKV, output_dBias, output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state, input_cu_seqlens, input_rng_state,
@@ -350,7 +375,7 @@ void nvte_fused_attn_bwd_qkvpacked(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]); const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked( fused_attn_fp8_bwd_qkvpacked(
b, max_seqlen, h, d, b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO, input_QKV, input_O, input_dO,
input_M, input_ZInv, input_M, input_ZInv,
@@ -395,12 +420,20 @@ void nvte_fused_attn_fwd_kvpacked(
Tensor *output_O = reinterpret_cast<Tensor*>(O); Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace); Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
// KV shape is [total_seqs, h, d]
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2]; auto ndim = input_Q->data.shape.size();
size_t h_q = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto ndim_kv = input_KV->data.shape.size();
size_t h_kv = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
h_kv = input_KV->data.shape[ndim_kv - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
h_kv = input_KV->data.shape[ndim_kv - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -410,12 +443,12 @@ void nvte_fused_attn_fwd_kvpacked(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
Q_type, KV_type, Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d); dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked( fused_attn_max_512_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O, input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -426,10 +459,19 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif #endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg = #if (CUDNN_VERSION >= 8903)
"The FP16/BF16 fused attention (arbitrary seqlen) currently " fused_attn_arbitrary_seqlen_fwd_kvpacked(
"only supports packed QKV input.\n"; b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
NVTE_ERROR(err_msg); is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else { } else {
@@ -471,12 +513,20 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias); Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace); Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
// KV shape is [total_seqs, h, d]
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2]; auto ndim = input_Q->data.shape.size();
size_t h_q = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto ndim_kv = input_KV->data.shape.size();
size_t h_kv = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
h_kv = input_KV->data.shape[ndim_kv - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
h_kv = input_KV->data.shape[ndim_kv - 3];
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -486,13 +536,13 @@ void nvte_fused_attn_bwd_kvpacked(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
Q_type, KV_type, Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d); dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked( fused_attn_max_512_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_dO, input_Q, input_KV, input_dO,
output_S, output_S,
@@ -503,10 +553,29 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif #endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg = #if (CUDNN_VERSION >= 8903)
"The FP16/BF16 fused attention (arbitrary seqlen) currently " Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
"only supports packed QKV input.\n"; Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO, input_Bias,
output_S,
output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg); NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else { } else {
@@ -546,7 +615,8 @@ void nvte_fused_attn_fwd(
auto ndim = input_Q->data.shape.size(); auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2]; size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
@@ -557,12 +627,12 @@ void nvte_fused_attn_fwd(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
Q_type, KV_type, Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d); dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd( fused_attn_max_512_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O, input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -575,7 +645,7 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O, input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -589,7 +659,7 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd( fused_attn_fp8_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, is_training, attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_output_S, output_O, input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
@@ -644,7 +714,8 @@ void nvte_fused_attn_bwd(
auto ndim = input_Q->data.shape.size(); auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2]; size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
...@@ -655,13 +726,13 @@ void nvte_fused_attn_bwd( ...@@ -655,13 +726,13 @@ void nvte_fused_attn_bwd(
nvte_get_fused_attn_backend( nvte_get_fused_attn_backend(
Q_type, KV_type, Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d); dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd( fused_attn_max_512_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_dO, input_Q, input_K, input_V, input_dO,
output_S, output_S,
...@@ -674,11 +745,17 @@ void nvte_fused_attn_bwd( ...@@ -674,11 +745,17 @@ void nvte_fused_attn_bwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]); Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
input_Bias = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
} else {
input_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
output_S, output_S,
output_dQ, output_dK, output_dV, output_dBias, output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q, input_cu_seqlens_kv,
...@@ -695,7 +772,7 @@ void nvte_fused_attn_bwd( ...@@ -695,7 +772,7 @@ void nvte_fused_attn_bwd(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]); const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd( fused_attn_fp8_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d, b, h_q, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_O, input_dO, input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv, input_M, input_ZInv,
......
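The entry points above now read two separate head counts, h_q from Q and h_kv from K, and pass both to backend selection and to the arbitrary-seqlen kernels. A minimal sketch of the shape convention this implies, assuming a grouped-query configuration; the layout, the example numbers, and the helper below are illustrative and not part of this diff:

#include <cstdint>
#include <vector>

// Illustrative only: derive the head count from a 4-D shape laid out as [..., heads, head_dim],
// mirroring the shape[ndim - 2] indexing in the diff above.
int64_t num_heads(const std::vector<int64_t>& shape) {
    return shape[shape.size() - 2];
}

int main() {
    std::vector<int64_t> q_shape{2, 128, 16, 64};  // h_q  = 16
    std::vector<int64_t> k_shape{2, 128, 4, 64};   // h_kv = 4 (grouped-query case)
    // h_q == h_kv recovers standard multi-head attention; h_kv == 1 is multi-query.
    return (num_heads(q_shape) % num_heads(k_shape) == 0) ? 0 : 1;
}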
@@ -9,6 +9,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cudnn_frontend.h>
+ #include <cudnn_frontend_utils.h>
#include <map>
#include <vector>
@@ -46,1476 +47,613 @@
namespace transformer_engine {
namespace fused_attn {
static cudnn_frontend::Tensor
createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
const cudnn_frontend::Tensor& sTensor,
std::vector<cudnn_frontend::Operation>* ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto scaleTensor = tensor_create(
tensorType, S_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto sScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 2000, s_dim,
s_stride, true, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a scale node
auto scale_op = binary_pw_op_create(sTensor, scaleTensor, sScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return sScaleTensor;
}
static cudnn_frontend::Tensor
createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_dim, k_stride, false, false); // is virtual
// first GEMM output
auto sTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, s_dim, s_stride, true, false); // is virtual
// Define the matmul 1 desc
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// Create a matmul 1 node
auto&& matmul_op_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op_builder.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(sTensor)
.setmatmulDesc(matmul_1_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op1 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op1));
return sTensor;
}
static cudnn_frontend::Tensor
createPaddingMask(int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);
// less than row output
auto lessThanRowTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);
// less than column output
auto lessThanColTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);
// padding mask (lessthanRow && lessthanCol)
auto paddingMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the less than comparison for row descriptor
auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for row Node.
auto lessThanRow_op = binary_pw_op_create(
rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
// Define the less than comparison for column descriptor
auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for col Node.
auto lessThanCol_op = binary_pw_op_create(
columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
// Define the logical AND descriptor for combining the row and column masks
auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
// Create a and node for combining lessThanRow and lessThanCol
auto paddingMaskAnd_op = binary_pw_op_create(
lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
/////////////////// Apply the mask //////////////////////////
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select Node.
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(lessThanRow_op));
ops->push_back(std::move(lessThanCol_op));
ops->push_back(std::move(paddingMaskAnd_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
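For reference, a minimal scalar sketch of what the padding-mask DAG above computes per score element; the -1.0e30f fill value matches the negInfinity constant bound to MASK_VAL_ID further down, and the function name is illustrative only:

// Illustrative scalar equivalent of the padding-mask DAG (not part of this diff).
float apply_padding_mask(float score, int row, int col,
                         int actual_seqlen_q, int actual_seqlen_kv) {
    const float neg_inf = -1.0e30f;  // same fill value as MASK_VAL_ID
    bool keep = (row < actual_seqlen_q) && (col < actual_seqlen_kv);
    return keep ? score : neg_inf;
}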
static cudnn_frontend::Tensor
createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim,
maskVal_stride, false, true); // is by value
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// create causal mask (row >= col)
auto causalMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 107, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index node
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the greater than equal to comparison descriptor
auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE);
// Create a greater than equal to node
auto rowGreaterCol_op = binary_pw_op_create(
rowIndexTensor, columnIndexTensor, causalMaskTensor, rowGreaterColDesc);
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select node
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor,
causalMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(rowGreaterCol_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
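Likewise, a scalar sketch of the causal-mask DAG above: scores are kept where the query (row) index is greater than or equal to the key (column) index and replaced by the large negative fill value elsewhere. Illustrative only, not part of this diff:

// Illustrative scalar equivalent of the causal-mask DAG.
float apply_causal_mask(float score, int row, int col) {
    const float neg_inf = -1.0e30f;  // same fill value as MASK_VAL_ID
    return (row >= col) ? score : neg_inf;
}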
static cudnn_frontend::Tensor
createSoftmaxForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, bool isTraining,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& sAfterMaskTensor) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
// max (x)
auto afterMaxReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// x - max(x)
auto afterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
auto afterExponentTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual;
// sum (e^(x - max(x)))
auto afterAddReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim,
afterReduction_stride, true, false); // is virtual
// log (sum (e^(x - max(x))))
auto afterLogLTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 154, afterReduction_dim,
afterReduction_stride, true, false);
// M + log (sum (e^(x - max(x))))
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, afterReduction_dim,
afterReduction_stride, !isTraining, false);
// not virtual if training is true, virtual if training is false
// divide (e/ sum(e))
auto afterSoftmaxTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 156)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(sAfterMaskTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract node
auto subtract_op = binary_pw_op_create(
sAfterMaskTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent node
auto exponent_op = unary_pw_op_create(
afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(afterAddReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Create log descriptor
auto logDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_LOG);
// Create log node
auto log_op = unary_pw_op_create(afterAddReductionTensor, afterLogLTensor, logDesc);
// Create add descriptor
auto addDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD);
// Create add node
auto add_op = binary_pw_op_create(
afterMaxReductionTensor, afterLogLTensor,
softmaxStatsTensor, addDesc);
// Define the division descriptor
auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV);
// Create a division node
auto division_op = binary_pw_op_create(
afterExponentTensor, afterAddReductionTensor,
afterSoftmaxTensor, divisionDesc);
ops->push_back(std::move(reductionMax_op));
ops->push_back(std::move(subtract_op));
ops->push_back(std::move(exponent_op));
ops->push_back(std::move(reductionAdd_op));
ops->push_back(std::move(log_op));
ops->push_back(std::move(add_op));
ops->push_back(std::move(division_op));
return afterSoftmaxTensor;
}
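A compact reference for the softmax DAG above: it computes a numerically stable softmax and, when training, stores the per-row statistic M + log(sum(e^(x - M))) in S_STATS_ID for the backward pass. A minimal scalar sketch, assuming a non-empty row of attention logits; not part of this diff:

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative scalar reference for the forward softmax DAG.
// Normalizes x in place and returns the per-row statistic written to S_STATS_ID.
float softmax_row_with_stats(std::vector<float>& x) {
    float m = x[0];
    for (float v : x) m = std::max(m, v);                   // max(x)
    float sum = 0.f;
    for (float& v : x) { v = std::exp(v - m); sum += v; }   // e^(x - max(x)) and its sum
    for (float& v : x) v /= sum;                            // softmax
    return m + std::log(sum);                               // M + log(sum(e^(x - M)))
}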
static cudnn_frontend::Tensor
createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// mask for the dropout
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
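The dropout DAG above implements inverted dropout: a Bernoulli mask with keep probability 1 - p, followed by a rescale so the expected value is preserved; the 1/(1 - p) factor matches the scale_dropout constant bound to D_CONST_ID in the variant pack below. A small host-side sketch, assuming p < 1; the RNG and names are stand-ins, not part of this diff:

#include <cstdint>
#include <random>
#include <vector>

// Illustrative inverted-dropout reference (p must be < 1, as the fwd path checks).
void inverted_dropout(std::vector<float>& x, float p, uint64_t seed, uint64_t offset) {
    std::mt19937_64 rng(seed + offset);                  // stand-in for the cuDNN RNG node
    std::bernoulli_distribution keep(1.0 - p);           // keep probability = 1 - p
    const float scale = 1.0f / (1.0f - p);               // same factor as scale_dropout
    for (float& v : x) v = keep(rng) ? v * scale : 0.f;  // mask, then rescale survivors
}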
static cudnn_frontend::Tensor
createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
double probability, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& afterSoftmaxTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
CUDNN_FRONTEND_UNUSED(d);
NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one");
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeed = tensor_create(
CUDNN_DATA_INT64, D_SEED_ID, scale_dim,
scale_stride, false, false); // not virtual
auto dropoutOffset = tensor_create(
CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim,
scale_stride, false, false); // not virtual
// after dropout tensor
auto afterDropoutTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(VIRTUAL_ID + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(true)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
tensorType, VIRTUAL_ID + 202, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask node
auto maskMul_op = binary_pw_op_create(
afterSoftmaxTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale node
auto scaleMul_op = binary_pw_op_create(
afterDropoutTensor, scaleDropoutTensor,
afterScaleTensor, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
static void
createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
cudnn_frontend::Tensor const &afterScaleDropoutTensor) {
NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one");
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
// Define the matmul 2 desc
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
// Create a matmul 2 node
auto&& matmul_op_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op_builder.setaMatDesc(afterScaleDropoutTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmatmulDesc(matmul_2_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor);
}
auto matmul_op2 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op2));
}
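Both the legacy DAG builders removed above and the v1 rewrite that follows memoize the built plan/graph in a thread-local map keyed by the attention descriptor, so identical problem shapes reuse a compiled plan. A stripped-down sketch of that pattern; the key fields and types here are illustrative, not the real FADescriptor_v1 or graph types:

#include <map>
#include <memory>
#include <tuple>

// Illustrative cache-by-descriptor pattern (stand-in types only).
struct ProblemKey {
    long b, h, hg, s_q, s_kv, d;
    bool operator<(const ProblemKey& o) const {
        return std::tie(b, h, hg, s_q, s_kv, d) < std::tie(o.b, o.h, o.hg, o.s_q, o.s_kv, o.d);
    }
};
struct Plan { /* the built execution graph would live here */ };

std::shared_ptr<Plan> get_or_build(const ProblemKey& key) {
    static thread_local std::map<ProblemKey, std::shared_ptr<Plan>> cache;
    auto it = cache.find(key);
    if (it != cache.end()) return it->second;   // cache hit: reuse the compiled plan
    auto plan = std::make_shared<Plan>();       // cache miss: build once, then store
    cache.emplace(key, plan);
    return plan;
}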
void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability, bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, NVTE_QKV_Layout layout,
void *devPtrQ, void *devPtrK, void *devPtrV, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *devPtrSoftmaxStats, void *devPtrO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
void *workspace, size_t *workspace_size, cudnn_frontend::DataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) { void *workspace, size_t *workspace_size,
try { cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) { bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
dropout_probability = 0.0f; bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
} bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
// also known as variable_sequence_length bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
bool padding_aware = (mask_type == NVTE_PADDING_MASK) || || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
(mask_type == NVTE_PADDING_CAUSAL_MASK); bool is_dropout = (is_training && dropout_probability != 0.0f);
FADescriptor descriptor{b, h, try {
s_q, s_kv, FADescriptor_v1 descriptor{b, h,
d, scaling_factor, hg, s_q,
is_training, dropout_probability, s_kv, d,
layout, NVTE_Bias_Type::NVTE_NO_BIAS, scaling_factor, is_training,
mask_type, tensorType, dropout_probability, layout,
false}; bias_type, mask_type,
tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_fprop_cache; namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_flash_f16_fprop_cache;
// Get plan from cache if cache is available, otherwise create one // Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return // if hit, return
auto it = cache.find(descriptor); auto it = cache.find(descriptor);
if (it != cache.end()) { if (it != cache.end()) {
auto plan = it->second; auto graph = it->second;
return plan; return graph;
} }
// otherwise, build the op_graph and the plan. Then update cache // otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const*> all_ops; auto mha_graph = std::make_shared<fe::graph::Graph>();
std::vector<cudnn_frontend::Operation> ops; mha_graph->set_io_data_type(tensorType)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
// Q * K^T .set_compute_data_type(fe::DataType_t::FLOAT);
auto sTensor = createQKBMM(
b, h, s_q, s_kv, d, padding_aware, layout, tensorType, &ops); std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
// Q * K^T * bmmScale std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
auto sScaleTensor = createScale(
b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops); std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
auto& sAfterMaskTensor = sScaleTensor; std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) { layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
sAfterMaskTensor = createCausalMask( generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor); layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
fe::graph::Scaled_dot_product_flash_attention_attributes
scaled_dot_product_flash_attention_options;
scaled_dot_product_flash_attention_options =
fe::graph::Scaled_dot_product_flash_attention_attributes()
.set_name("flash_attention")
.set_is_inference(!is_training)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
scaled_dot_product_flash_attention_options.set_alibi_mask(is_alibi);
if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
scaled_dot_product_flash_attention_options.set_bias(bias);
} }
if (padding_aware) { if (is_padding) {
sAfterMaskTensor = createPaddingMask( seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
b, h, s_q, s_kv, d, layout, tensorType, &ops, sAfterMaskTensor); .set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
scaled_dot_product_flash_attention_options.set_padding_mask(is_padding)
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
} }
NVTE_CHECK(dropout_probability != 1.0f, if (is_dropout) {
"Dropout probability cannot be 1.0"); dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
auto softmax_output = createSoftmaxForward( .set_dim({1, 1, 1, 1})
b, h, s_q, s_kv, is_training, &ops, sAfterMaskTensor); .set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
scaled_dot_product_flash_attention_options.set_dropout(
dropout_probability, dropout_seed, dropout_offset);
}
// Dropout(softmax) auto [O, Stats] = mha_graph->scaled_dot_product_flash_attention(
auto dropout_output = createDropoutForward( Q, K, V, scaled_dot_product_flash_attention_options);
b, h, s_q, s_kv, d,
dropout_probability, tensorType, &ops, softmax_output);
createSVBMM(b, h, s_q, s_kv, d, padding_aware,
layout, tensorType, &ops, dropout_output);
for (unsigned int i = 0; i < ops.size(); i++) { std::vector<int64_t> o_stride(4);
all_ops.push_back(&ops[i]); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
} layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
// Create an Operation Graph if (is_training) {
auto opGraph = cudnn_frontend::OperationGraphBuilder() Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.setHandle(handle) .set_dim({b, h, s_q, 1})
.setOperationGraph(all_ops.size(), all_ops.data()) .set_stride({h * s_q, s_q, 1, 1});
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig,
filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr,
CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_fprop: No config returned by the heuristics");
} }
auto plan = cudnn_frontend::ExecutionPlanBuilder() std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
.setHandle(handle) std::shared_ptr<fe::graph::Tensor_attributes>, // K
.setEngineConfig(filtered_configs[0], opGraph.getTag()) std::shared_ptr<fe::graph::Tensor_attributes>, // V
.build(); std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes> > // O
cache.insert({descriptor, plan}); key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
return plan; auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
mha_graph->validate();
mha_graph->build_operation_graph(handle);
mha_graph->create_execution_plans({fe::HeurMode_t::A});
mha_graph->check_support(handle);
mha_graph->build_plans(handle);
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
}; };
auto plan = get_plan(fmha_fprop_cache, descriptor); auto [mha_graph, Q, K, V, attn_scale, O, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_flash_f16_fprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize(); auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) { if (workspace == nullptr) {
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size; *workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return; return;
} }
// Prepare actual seqlen // Build variant pack
constexpr size_t nthreads_per_block = 128; std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; {Q, devPtrQ},
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size; {K, devPtrK},
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t); {V, devPtrV},
{attn_scale, &scaling_factor},
{O, devPtrO}};
if (padding_aware) { if (is_training) {
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>( variant_pack[Stats] = devPtrSoftmaxStats;
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV),
static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
} }
std::set<std::pair<uint64_t, void*>> data_ptrs; if (is_bias) {
// Add all the data pointers to be used in the variant pack variant_pack[bias] = devPtrBias;
float negInfinity = -1.0E+30f;
float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrK));
data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrV));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
if (padding_aware) {
data_ptrs.insert(std::pair<uint64_t, void*>(Q_SEQLEN_ID, devActualSeqlenQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_SEQLEN_ID, devActualSeqlenK));
} }
// If training mode, we write out softmax stats if (is_padding) {
if (is_training) { constexpr size_t nthreads_per_block = 128;
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats)); const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV),
static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
} }
auto variantPack = cudnn_frontend::VariantPackBuilder() if (is_dropout) {
.setWorkspacePointer(workspace) variant_pack[dropout_seed] = devPtrDropoutSeed;
.setDataPointers(data_ptrs) variant_pack[dropout_offset] = devPtrDropoutOffset;
.build(); }
NVTE_CHECK_CUDNN( mha_graph->execute(handle, variant_pack, workspace);
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) { } catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what()); NVTE_ERROR(e.what());
} }
} }
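The forward implementation above follows a two-call workspace protocol: called with workspace == nullptr it only reports the required size (plan workspace plus 2 * b int32 actual-seqlen buffers) and returns; the caller allocates and calls again to execute. A generic, self-contained sketch of that protocol; the function and sizes below are stand-ins, not the real API:

#include <cstddef>
#include <vector>

// Illustrative two-call workspace protocol.
void do_work(void* workspace, size_t* workspace_size) {
    const size_t needed = 1024;      // e.g. plan workspace + 2 * b * sizeof(int32_t)
    if (workspace == nullptr) {      // query phase: report the size and return
        *workspace_size = needed;
        return;
    }
    // execute phase: workspace points at caller-allocated memory of *workspace_size bytes
}

int main() {
    size_t bytes = 0;
    do_work(nullptr, &bytes);        // 1) ask how much scratch memory is required
    std::vector<char> scratch(bytes);  // 2) caller allocates
    do_work(scratch.data(), &bytes);   // 3) run with the provided workspace
}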
void fused_attn_arbitrary_seqlen_bwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrKTranspose, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrVTranspose, void* devPtrO, void* devPtrSoftmaxStats, void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrO, void* devPtrSoftmaxStats, void* devPtrBias,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, void *workspace, size_t *workspace_size, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudaStream_t stream, cudnnHandle_t handle, bool use_workspace_opt) { cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
try { cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// also known as variable_sequence_length bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool padding_aware = (mask_type == NVTE_PADDING_MASK) || bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
(mask_type == NVTE_PADDING_CAUSAL_MASK); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
FADescriptor descriptor{b, h, try {
s_q, s_kv, FADescriptor_v1 descriptor{b, h,
d, scaling_factor, hg, s_q,
true, dropout_probability, s_kv, d,
layout, NVTE_Bias_Type::NVTE_NO_BIAS, scaling_factor, true,
mask_type, tensorType, dropout_probability, layout,
use_workspace_opt}; bias_type, mask_type,
tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_bprop_cache; namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_flash_f16_bprop_cache;
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { // Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor); auto it = cache.find(descriptor);
if (it != cache.end()) { if (it != cache.end()) {
return it->second; auto graph = it->second;
} return graph;
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, q_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_transpose_dim[4] = {b, h, d, s_kv};
int64_t k_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t v_transpose_dim[4] = {b, h, d, s_kv};
int64_t v_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, v_transpose_stride,
layout, NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose);
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, p_stride,
layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t p_transpose_dim[4] = {b, h, s_kv, s_q};
int64_t p_transpose_stride[4];
p_transpose_stride[0] = p_stride[0];
p_transpose_stride[1] = p_stride[1];
p_transpose_stride[2] = p_stride[3];
p_transpose_stride[3] = p_stride[2];
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, o_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t dqAccum_dim[4] = {b, h, s_q, d};
int64_t dqAccum_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
/*******************************************************************************
* Dot product dO * O */
// output and gradient of the output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
auto dOTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto dotProductTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID, o_dim,
o_stride, true, false); // is virtual
// Create pointwise mul
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// do * O
auto dotProductOp = binary_pw_op_create(
dOTensor, oTensor, dotProductTensor, multiplyDesc);
ops.push_back(std::move(dotProductOp));
/*******************************************************************************
* Reduction(dO * O) */
int64_t reduction_dim[4] = {b, h, s_q, 1};
int64_t reduction_stride[4] = {h * s_q, s_q, 1, 1};
// reduction(dO * O)
auto afterReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim,
reduction_stride, true, false); // is virtual
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dotProductTensor)
.setyDesc(afterReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
ops.push_back(std::move(reductionAdd_op));
/*******************************************************************************
* reduction(dO * O) * scale prob -> softmaxSum */
auto softmaxSumTensor = tensor_create(
CUDNN_DATA_FLOAT, S_SUM_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto scaleProbTensor = tensor_create(
CUDNN_DATA_FLOAT, SCALE_PROB, scale_dim,
scale_stride, false, true); // not virtual
auto softmaxSumOp = binary_pw_op_create(
afterReductionTensor, scaleProbTensor,
softmaxSumTensor, multiplyDesc);
ops.push_back(std::move(softmaxSumOp));
/*******************************************************************************
* Q @ K.T -> P */
// Inputs from fprop
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_transpose_dim,
k_transpose_stride, false, false);
auto pTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2, p_dim,
p_stride, true, false); // is virtual
// matmul to calculate dvTensor
auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto&& matmul_op_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op_builder.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(pTensor)
.setmatmulDesc(matmul_0_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op0 = matmul_op_builder.build();
ops.push_back(std::move(matmul_op0));
/*******************************************************************************
* P * bmmScale -> pAfterScale */
auto bmmScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim,
scale_stride, false, true); // not virtual and by value
auto pAfterScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 2000, p_dim,
p_stride, true, false); // virtual
auto scaleOp = binary_pw_op_create(
pTensor, bmmScaleTensor, pAfterScaleTensor, multiplyDesc);
ops.push_back(std::move(scaleOp));
/*******************************************************************************
* Causal masking -> pAfterMaskTensor */
auto& pAfterMaskTensor = pAfterScaleTensor;
if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) {
pAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
}
if (padding_aware) {
pAfterMaskTensor = createPaddingMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterMaskTensor);
}
/*******************************************************************************
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
auto pAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 3, p_dim,
p_stride, true, false); // is virtual
auto softmaxStatsTensor = tensor_create(
CUDNN_DATA_FLOAT, S_STATS_ID, reduction_dim,
reduction_stride, false, false); // not virtual
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
auto subtract_op = binary_pw_op_create(
pAfterMaskTensor, softmaxStatsTensor,
pAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op));
/*******************************************************************************
* e^(pAfterSubtract) -> pAfterSoftmax */
auto pAfterSoftmaxTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 4, p_dim,
p_stride, true, false); // is virtual
auto expDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
auto exp_op = unary_pw_op_create(
pAfterSubtractTensor, pAfterSoftmaxTensor, expDesc);
ops.push_back(std::move(exp_op));
/*******************************************************************************
* Dropout -> afterScaleDropout */
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 5, p_dim,
p_stride, true, false); // is virtual
auto afterScaleDropoutTensor = createDropoutBackward(
b, h, s_q, s_kv, d, dropout_probability, tensorType,
&ops, pAfterSoftmaxTensor, dropoutMaskTensor);
/*******************************************************************************
* afterScaleDropout -> sTransposeTensor */
auto sTransposeTensor = tensor_create(
tensorType, VIRTUAL_ID + 6, p_transpose_dim,
p_transpose_stride, true, false); // is virtual
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(afterScaleDropoutTensor)
.setyDesc(sTransposeTensor)
.build();
ops.push_back(std::move(reshape_op));
// Outputs of bprop
int64_t dq_dim[4] = {b, h, s_q, d};
int64_t dq_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dq_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t dk_dim[4] = {b, h, s_kv, d};
int64_t dk_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dk_stride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
int64_t dv_dim[4] = {b, h, s_kv, d};
int64_t dv_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dv_stride,
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
// Outputs of backprop
auto dQTensor = tensor_create(tensorType, dQ_ID, dq_dim, dq_stride, false, false);
auto dKTensor = tensor_create(tensorType, dK_ID, dk_dim, dk_stride, false, false);
auto dVTensor = tensor_create(tensorType, dV_ID, dv_dim, dv_stride, false, false);
// not virtual
/*******************************************************************************
* sTransposeTensor @ dO -> dV */
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto&& matmul_op1_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op1_builder.setaMatDesc(sTransposeTensor)
.setbMatDesc(dOTensor)
.setcMatDesc(dVTensor)
.setmatmulDesc(matmul_1_Desc);
if (padding_aware) {
matmul_op1_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor);
}
auto matmul_op1 = matmul_op1_builder.build();
ops.push_back(std::move(matmul_op1));
/*******************************************************************************
* dO @ V.T -> dS */
auto vTransposeTensor = tensor_create(
tensorType, V_ID, v_transpose_dim,
v_transpose_stride, false, false);
auto dSTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 7, p_dim,
p_stride, true, false); // is virtual
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto&& matmul_op2_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op2_builder.setaMatDesc(dOTensor)
.setbMatDesc(vTransposeTensor)
.setcMatDesc(dSTensor)
.setmatmulDesc(matmul_2_Desc);
if (padding_aware) {
matmul_op2_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op2 = matmul_op2_builder.build();
ops.push_back(std::move(matmul_op2));
/*******************************************************************************
* dS * dropoutMask -> dSAfterDropout */
auto dSAfterDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 8, p_dim,
p_stride, true, false); // is virtual
auto multiply_op = binary_pw_op_create(
dSTensor, dropoutMaskTensor,
dSAfterDropoutTensor, multiplyDesc);
ops.push_back(std::move(multiply_op));
/*******************************************************************************
* dSAfterDropout - softmaxSum -> dsAfterSubtract */
auto dsAfterSubtractTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 9, p_dim,
p_stride, true, false); // is virtual
auto subtract_op2 = binary_pw_op_create(
dSAfterDropoutTensor, softmaxSumTensor,
dsAfterSubtractTensor, subtractDesc);
ops.push_back(std::move(subtract_op2));
/*******************************************************************************
* dsAfterSubtract * afterSoftmax -> dP */
auto dPTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 10, p_dim,
p_stride, true, false); // is virtual
auto multiply_op2 = binary_pw_op_create(
dsAfterSubtractTensor, pAfterSoftmaxTensor,
dPTensor, multiplyDesc);
ops.push_back(std::move(multiply_op2));
/*******************************************************************************
* dP * scaleDropout -> dPAfterDropoutScale */
auto dPAfterDropoutScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 11, p_dim,
p_stride, true, false); // is virtual
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto multiply_op3 = binary_pw_op_create(
dPTensor, scaleDropoutTensor,
dPAfterDropoutScaleTensor, multiplyDesc);
ops.push_back(std::move(multiply_op3));
/*******************************************************************************
* dPAfterDropoutScale * bmmScale -> dPScaledTensor */
auto dPScaledTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 12, p_dim,
p_stride, true, false); // is virtual
auto multiply_op4 = binary_pw_op_create(
dPAfterDropoutScaleTensor, bmmScaleTensor,
dPScaledTensor, multiplyDesc);
ops.push_back(std::move(multiply_op4));
/*******************************************************************************
* K.T -> K */
int64_t kDim[4] = {b, h, s_kv, d};
int64_t kStride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, kStride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
auto kTensor = tensor_create(
tensorType, VIRTUAL_ID + 13, kDim,
kStride, true, false); // is virtual
auto reshape_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(kTransposeTensor)
.setyDesc(kTensor)
.build();
ops.push_back(std::move(reshape_op2));
/*******************************************************************************
* dP @ K -> dqAccumTensor / dqTensor */
auto dqAccumTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dqAccum_dim)
.setStride(4, dqAccum_stride)
.setId(dQ_ACCUM_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(false)
.setByValue(false)
.setReorderType(
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)
.build();
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto&& matmul_op3_builder =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
matmul_op3_builder.setaMatDesc(dPScaledTensor)
.setbMatDesc(kTensor)
.setmatmulDesc(matmul_3_Desc);
if (use_workspace_opt) {
matmul_op3_builder.setcMatDesc(dQTensor);
} else {
matmul_op3_builder.setcMatDesc(dqAccumTensor);
}
if (padding_aware) {
matmul_op3_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor);
    }

    auto matmul_op3 = matmul_op3_builder.build();
    ops.push_back(std::move(matmul_op3));

    /*******************************************************************************
     *                            dP.T @ Q -> dK */

    auto dPTransposeTensor = tensor_create(
                        CUDNN_DATA_FLOAT, VIRTUAL_ID + 14, p_transpose_dim,
                        p_transpose_stride, true, false);  // is virtual
    auto reshape_op3 = cudnn_frontend::OperationBuilder(
                        CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
                        .setxDesc(dPScaledTensor)
                        .setyDesc(dPTransposeTensor)
                        .build();
    ops.push_back(std::move(reshape_op3));

    auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder()
                        .setComputeType(CUDNN_DATA_FLOAT)
                        .setPaddingValue(0.0f)
                        .build();
    auto&& matmul_op4_builder =
        cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
    matmul_op4_builder.setaMatDesc(dPTransposeTensor)
                      .setbMatDesc(qTensor)
                      .setcMatDesc(dKTensor)
                      .setmatmulDesc(matmul_4_Desc);
    if (padding_aware) {
      matmul_op4_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor);
    }
    auto matmul_op4 = matmul_op4_builder.build();
    ops.push_back(std::move(matmul_op4));

    /*******************************************************************************
     *                   dqAccumTensor @ identity -> dqTensor */

    if (!use_workspace_opt) {
      auto identityDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_IDENTITY);
      auto identity_op = unary_pw_op_create(dqAccumTensor, dQTensor, identityDesc);
      ops.push_back(std::move(identity_op));
    }

    for (unsigned int i = 0; i < ops.size(); i++) {
      all_ops.push_back(&ops[i]);
    }

    // Create an Operation Graph
    auto opGraph = cudnn_frontend::OperationGraphBuilder()
                       .setHandle(handle)
                       .setOperationGraph(all_ops.size(), all_ops.data())
                       .build();

    cudnn_frontend::EngineConfigList filtered_configs;
    auto statuses = cudnn_frontend::get_heuristics_list<1>(
        {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true);

    if (filtered_configs.size() == 0) {
      cudnn_frontend::set_error_and_throw_exception(
          nullptr, CUDNN_STATUS_NOT_SUPPORTED,
          "run_mha_bprop: No config returned by the heuristics");
    }

    auto plan = cudnn_frontend::ExecutionPlanBuilder()
                    .setHandle(handle)
                    .setEngineConfig(filtered_configs[0], opGraph.getTag())
                    .build();

    cache.insert({descriptor, plan});
    return plan;
  };

  auto plan = get_plan(fmha_bprop_cache, descriptor);

  auto plan_workspace_size = plan.getWorkspaceSize();

  // Exit to request upper level API to allocate memory if needed
  size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float);
  size_t dqAccum_workspace_size = use_workspace_opt ? 0 : b * s_q * h * d * sizeof(float);
  size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
  if (workspace == nullptr) {
    *workspace_size = plan_workspace_size + softmaxSum_workspace_size
                      + dqAccum_workspace_size + actual_seqlen_workspace_size;
    return;
  }

  void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
  void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
                              + softmaxSum_workspace_size;
  if (!use_workspace_opt) {
    NVTE_CHECK_CUDA(cudaMemsetAsync(
        devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
  }

  constexpr size_t nthreads_per_block = 128;
  const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
  void *devActualSeqlenQ =
      static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
  void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
  cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
      b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
      static_cast<const int32_t *>(devPtrCuSeqlenKV),
      static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
  NVTE_CHECK_CUDA(cudaGetLastError());

  std::set<std::pair<uint64_t, void *>> data_ptrs;
  // add all the data pointers to be used in the variant pack
  float negInfinity = -1.0E+31f;
  float scale_dropout = 1.0f/(1.0f - dropout_probability);
  data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ));
  if (!use_workspace_opt) {
    data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ACCUM_ID, devPtrdQAccumulator));
  }
  data_ptrs.insert(std::pair<uint64_t, void*>(dK_ID, devPtrdK));
  data_ptrs.insert(std::pair<uint64_t, void*>(dV_ID, devPtrdV));
  data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
  data_ptrs.insert(std::pair<uint64_t, void*>(K_ID, devPtrKTranspose));
  data_ptrs.insert(std::pair<uint64_t, void*>(V_ID, devPtrVTranspose));
  data_ptrs.insert(std::pair<uint64_t, void*>(O_ID, devPtrO));
  data_ptrs.insert(std::pair<uint64_t, void*>(dO_ID, devPtrdO));
  data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
  data_ptrs.insert(std::pair<uint64_t, void*>(S_SUM_ID, devPtrSoftmaxSum));
  data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
  data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
  data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));

  if (padding_aware) {
    data_ptrs.insert(std::pair<uint64_t, void *>(Q_SEQLEN_ID, devActualSeqlenQ));
    data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
  }

  float scaleProb = 1.0f - dropout_probability;
  data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
  data_ptrs.insert(std::pair<uint64_t, void*>(S_CONST_ID, &scaling_factor));
  data_ptrs.insert(std::pair<uint64_t, void*>(SCALE_PROB, &scaleProb));

  auto variantPack = cudnn_frontend::VariantPackBuilder()
                         .setWorkspacePointer(workspace)
                         .setDataPointers(data_ptrs)
                         .build();
  NVTE_CHECK_CUDNN(
      cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));

    }

    // otherwise, build the op_graph and the plan. Then update cache
    auto mha_graph = std::make_shared<fe::graph::Graph>();
    mha_graph->set_io_data_type(tensorType)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
    std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
    std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;

    std::vector<int64_t> q_stride(4);
    std::vector<int64_t> k_stride(4);
    std::vector<int64_t> v_stride(4);
    std::vector<int64_t> o_stride(4);
    generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
                          layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
    generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
                          layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
    generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
                          layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
    generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
                          layout, NVTE_QKV_Matrix::NVTE_O_Matrix);

    q = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("Q")
                          .set_dim({b, h, s_q, d})
                          .set_stride(q_stride));
    k = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("K")
                          .set_dim({b, hg, s_kv, d})
                          .set_stride(k_stride));
    v = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("V")
                          .set_dim({b, hg, s_kv, d})
                          .set_stride(v_stride));
    o = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("O")
                          .set_dim({b, h, s_q, d})
                          .set_stride(o_stride));
    dO = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("dO")
                          .set_dim({b, h, s_q, d})
                          .set_stride(o_stride));
    stats = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("stats")
                          .set_dim({b, h, s_q, 1})
                          .set_stride({h * s_q, s_q, 1, 1})
                          .set_data_type(fe::DataType_t::FLOAT));
    attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("attn_scale")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_is_pass_by_value(true)
                          .set_data_type(fe::DataType_t::FLOAT));

    fe::graph::Scaled_dot_product_flash_attention_backward_attributes
        scaled_dot_product_flash_attention_backward_options;
    scaled_dot_product_flash_attention_backward_options =
        fe::graph::Scaled_dot_product_flash_attention_backward_attributes()
            .set_name("flash_attention_backward")
            .set_causal_mask(is_causal)
            .set_attn_scale(attn_scale);
    scaled_dot_product_flash_attention_backward_options.set_alibi_mask(is_alibi);

    if (is_bias) {
      bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("bias")
                          .set_dim({1, h, s_q, s_kv})
                          .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
      dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("dBias")
                          .set_dim({1, h, s_q, s_kv})
                          .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
      scaled_dot_product_flash_attention_backward_options.set_bias(bias);
      scaled_dot_product_flash_attention_backward_options.set_dbias(dBias);
    }

    if (is_padding) {
      seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("seq_q")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(fe::DataType_t::INT32));
      seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("seq_kv")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(fe::DataType_t::INT32));
      scaled_dot_product_flash_attention_backward_options.set_padding_mask(is_padding)
                          .set_seq_len_q(seq_q)
                          .set_seq_len_kv(seq_kv);
    }

    if (is_dropout) {
      dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("Seed")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(fe::DataType_t::INT64));
      dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("Offset")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(fe::DataType_t::INT64));
      scaled_dot_product_flash_attention_backward_options.set_dropout(
                          dropout_probability, dropout_seed, dropout_offset);
    }

    auto [dQ, dK, dV] = mha_graph->scaled_dot_product_flash_attention_backward(
        q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options);

    dQ->set_output(true)
        .set_dim({b, h, s_q, d})
        .set_stride(q_stride);
    dK->set_output(true)
        .set_dim({b, hg, s_kv, d})
        .set_stride(k_stride);
    dV->set_output(true)
        .set_dim({b, hg, s_kv, d})
        .set_stride(v_stride);

    std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // q
               std::shared_ptr<fe::graph::Tensor_attributes>,  // k
               std::shared_ptr<fe::graph::Tensor_attributes>,  // v
               std::shared_ptr<fe::graph::Tensor_attributes>,  // o
               std::shared_ptr<fe::graph::Tensor_attributes>,  // dO
               std::shared_ptr<fe::graph::Tensor_attributes>,  // stats
               std::shared_ptr<fe::graph::Tensor_attributes>,  // attn_scale
               std::shared_ptr<fe::graph::Tensor_attributes>,  // dQ
               std::shared_ptr<fe::graph::Tensor_attributes>,  // dK
               std::shared_ptr<fe::graph::Tensor_attributes> >  // dV
        key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
    auto bias_tuple = is_bias ?
        std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
    auto padding_tuple = is_padding ?
        std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
    auto dropout_tuple = is_dropout ?
        std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
    auto return_empty_tuple = std::tuple_cat(
        std::make_tuple(nullptr), key_tensors_tuple,
        bias_tuple, padding_tuple, dropout_tuple);

    mha_graph->validate();
    mha_graph->build_operation_graph(handle);
    mha_graph->create_execution_plans({fe::HeurMode_t::A});
    mha_graph->check_support(handle);
    mha_graph->build_plans(handle);

    auto return_tuple = std::tuple_cat(
        std::make_tuple(mha_graph), key_tensors_tuple,
        bias_tuple, padding_tuple, dropout_tuple);
    cache.insert({descriptor, return_tuple});
    return return_tuple;
  };

  auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV,
        bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
        sdpa_flash_f16_bprop_cache, descriptor);

  auto plan_workspace_size = mha_graph->get_workspace_size();

  // Exit to request upper level API to allocate memory if needed
  size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
  if (workspace == nullptr) {
    *workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
    return;
  }

  // build variant pack
  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
      {q, devPtrQ},
      {k, devPtrKTranspose},
      {v, devPtrVTranspose},
      {o, devPtrO},
      {dO, devPtrdO},
      {stats, devPtrSoftmaxStats},
      {attn_scale, &scaling_factor},
      {dQ, devPtrdQ},
      {dK, devPtrdK},
      {dV, devPtrdV},
  };

  if (is_bias) {
    variant_pack[bias] = devPtrBias;
    variant_pack[dBias] = devPtrdBias;
  }

  if (is_padding) {
    constexpr size_t nthreads_per_block = 128;
    const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
    void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
    void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
    cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
        b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
        static_cast<const int32_t *>(devPtrCuSeqlensKV),
        static_cast<int32_t *>(devActualSeqlenQ),
        static_cast<int32_t *>(devActualSeqlenKV));
    variant_pack[seq_q] = devActualSeqlenQ;
    variant_pack[seq_kv] = devActualSeqlenKV;
  }

  if (is_dropout) {
    variant_pack[dropout_seed] = devPtrDropoutSeed;
    variant_pack[dropout_offset] = devPtrDropoutOffset;
  }

  mha_graph->execute(handle, variant_pack, workspace);
  } catch (cudnn_frontend::cudnnException &e) {
    NVTE_ERROR(e.what());
  }
}
}  // namespace fused_attn

using namespace transformer_engine::fused_attn;
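// Note added for readability (not part of the original sources): get_graph above memoizes the
// fully built cuDNN frontend v1 graph together with its tensor handles, keyed by the problem
// descriptor, so the validate / build_operation_graph / create_execution_plans / check_support /
// build_plans sequence runs once per unique shape. A minimal sketch of that lookup, with the
// cache and descriptor types assumed rather than taken from this file:
//
//   static thread_local std::unordered_map<Descriptor, graph_and_tensors> cache;  // types assumed
//   if (auto it = cache.find(descriptor); it != cache.end()) return it->second;   // cache hit
//   // ... otherwise build the graph as above, then cache.insert({descriptor, return_tuple});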
void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d] const DType QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr; void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
}
void *devPtrQ = static_cast<void *>(devPtrQKV); void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr; void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr; void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr; devPtrS = output_S->data.dptr;
output_S->data.shape = {batch, num_head, max_seqlen, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr; output_rng_state->data.dptr = rng_state->data.dptr;
output_rng_state->data.shape = {2}; } else if (Aux_CTX_Tensors->size == 3) {
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else { } else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_attn_heads,
is_training, attn_scale, p_dropout, qkv_layout, mask_type, max_seqlen, max_seqlen, head_dim,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, is_training, attn_scale, p_dropout, qkv_layout,
devPtrCuSeqlens, devPtrCuSeqlens, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), devPtrCuSeqlens, devPtrCuSeqlens,
workspace->data.dptr, &workspace_size, stream, handle); get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
@@ -1532,29 +670,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
  }
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_heads,
size_t head_dim, float attn_scale, float p_dropout, size_t max_seqlen, size_t head_dim, float attn_scale,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr; void *devPtrQKV = input_QKV->data.dptr;
auto stride = 2 * num_head * head_dim; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
}
void *devPtrQ = devPtrQKV; void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrO = input_O->data.dptr; void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr; void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
void *devPtrdBias = output_dBias->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr; void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV; void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride); void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
@@ -1563,50 +711,208 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
void *devPtrSoftmaxStats = nullptr; void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr; devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
const auto qkv_type = input_QKV->data.dtype; const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
bool use_workspace_opt = false; fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
#if (CUDNN_VERSION >= 8905) max_seqlen, max_seqlen, head_dim,
const int device_id = cuda::current_device(); attn_scale, p_dropout, qkv_layout,
const int sm_arch_ = cuda::sm_arch(device_id); bias_type, mask_type,
if (sm_arch_ >= 90) { devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
// quick estimate of dp workspace size devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
size_t max_seqlen_div_up_q = ((max_seqlen + 64 - 1) / 64) * 64; devPtrDropoutSeed, devPtrDropoutOffset,
size_t max_seqlen_div_up_kv = ((max_seqlen + 64 - 1) / 64) * 64; devPtrCuSeqlens, devPtrCuSeqlens,
size_t required_dp_workspace = get_cudnn_fe_dtype(qkv_type), workspace->data.dptr,
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576; &workspace_size, stream, handle);
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256; if (workspace_size > 0) {
if (required_dp_workspace <= max_allowed_dp_workspace) { if (workspace->data.dptr == nullptr) {
use_workspace_opt = true; workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
} }
use_workspace_opt = transformer_engine::getenv<bool>( } else if (workspace_size == 0) {
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt); workspace->data.shape = {1};
#if (CUDNN_VERSION < 8906) workspace->data.dtype = DType::kByte;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); return;
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) } else {
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) { NVTE_ERROR("Unexpected workspace_size.");
use_workspace_opt = false; }
}
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
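  // Note (added, illustrative only): the factor of 2 in the strides above is the element size
  // in bytes for FP16/BF16, since the offsets are applied to an int8_t-cast pointer. E.g. for
  // NVTE_HD_2HD, V begins num_attn_heads * head_dim elements after K inside the packed KV
  // tensor, which is 2 * num_attn_heads * head_dim bytes as computed above; for NVTE_HD_H2D
  // the per-head interleaving gives head_dim elements, i.e. 2 * head_dim bytes.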
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} }
#endif } else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, void* devPtrDropoutSeed = rng_state->data.dptr;
attn_scale, p_dropout, qkv_layout, mask_type, void* devPtrDropoutOffset = reinterpret_cast<void *>(
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrCuSeqlens, devPtrCuSeqlens, size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
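// Usage sketch (added for illustration; not part of the original file): these wrappers follow a
// two-pass workspace protocol. When workspace->data.dptr is nullptr, the call only reports the
// required size via workspace->data.shape = {workspace_size} and returns; the caller is expected
// to allocate that many bytes and invoke the function a second time, e.g.
//
//   Tensor ws{};                                   // ws.data.dptr == nullptr: size query only
//   fused_attn_arbitrary_seqlen_fwd_kvpacked(..., &ws, stream, handle);
//   // allocate ws.data.shape[0] bytes on device, set ws.data.dptr, then run for real
//   fused_attn_arbitrary_seqlen_fwd_kvpacked(..., &ws, stream, handle);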
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride);
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type), workspace->data.dptr, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
&workspace_size, stream, handle, use_workspace_opt); get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1624,8 +930,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
}
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t num_head, size_t head_dim, bool is_training, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
@@ -1640,22 +946,49 @@ void fused_attn_arbitrary_seqlen_fwd(
  void *devPtrV = input_V->data.dptr;
  void *devPtrO = output_O->data.dptr;
  void *devPtrS = nullptr;
void *devPtrBias = input_Bias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr; devPtrS = output_S->data.dptr;
output_S->data.shape = {batch, num_head, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr; output_rng_state->data.dptr = rng_state->data.dptr;
output_rng_state->data.shape = {2}; } else if (Aux_CTX_Tensors->size == 3) {
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else { } else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
@@ -1664,18 +997,18 @@ void fused_attn_arbitrary_seqlen_fwd(
  void* devPtrDropoutOffset = reinterpret_cast<void *>(
      reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
is_training, attn_scale, p_dropout, qkv_layout, mask_type, max_seqlen_q, max_seqlen_kv,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, head_dim, is_training, attn_scale, p_dropout, qkv_layout,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), devPtrCuSeqlensQ, devPtrCuSeqlensKV,
workspace->data.dptr, &workspace_size, stream, handle); get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -1692,13 +1025,14 @@ void fused_attn_arbitrary_seqlen_fwd(
  }
}
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t num_head, size_t head_dim, float attn_scale, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_kv,
@@ -1712,57 +1046,37 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
  void *devPtrV = input_V->data.dptr;
  void* devPtrO = input_O->data.dptr;
  void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
}
  void *devPtrdQ = output_dQ->data.dptr;
  void *devPtrdK = output_dK->data.dptr;
  void *devPtrdV = output_dV->data.dptr;
  void *devPtrSoftmaxStats = nullptr;
  devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
  void* devPtrDropoutSeed = rng_state->data.dptr;
  void* devPtrDropoutOffset = reinterpret_cast<void *>(
      reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0; size_t workspace_size = 0;
bool use_workspace_opt = false; fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
#if (CUDNN_VERSION >= 8905) max_seqlen_q, max_seqlen_kv,
const int device_id = cuda::current_device(); head_dim, attn_scale, p_dropout, qkv_layout,
const int sm_arch_ = cuda::sm_arch(device_id); bias_type, mask_type,
if (sm_arch_ >= 90) { devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
// quick estimate of dp workspace size devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
size_t max_seqlen_div_up_q = ((max_seqlen_q + 64 - 1) / 64) * 64;
size_t max_seqlen_div_up_kv = ((max_seqlen_kv + 64 - 1) / 64) * 64;
size_t required_dp_workspace =
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
if (required_dp_workspace <= max_allowed_dp_workspace) {
use_workspace_opt = true;
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
#if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
#endif
}
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
&workspace_size, stream, handle, use_workspace_opt); get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
@@ -12,55 +12,82 @@
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t head_size, bool is_training, float attn_scale, size_t batch, size_t num_attn_heads, size_t max_seqlen,
float p_dropout, NVTE_QKV_Layout qkv_layout, size_t head_size, bool is_training, float attn_scale,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV, const Tensor *input_Bias, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_QKV, const Tensor *input_Bias,
const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t head_dim, float attn_scale, float p_dropout, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_O, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_dO, Tensor *output_S, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
Tensor *output_dQKV, Tensor *output_dBias, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, void fused_attn_arbitrary_seqlen_fwd(
size_t num_head, size_t head_size, bool is_training, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool is_training, float attn_scale, float p_dropout,
const Tensor *input_Q, const Tensor *input_K, NVTE_QKV_Layout qkv_layout,
const Tensor *input_V, const Tensor *input_Bias, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_Q, const Tensor *input_K,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *rng_state, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, void fused_attn_arbitrary_seqlen_bwd(
size_t num_head, size_t head_dim, float attn_scale, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
float p_dropout, NVTE_QKV_Layout qkv_layout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_dO, Tensor *output_S, const Tensor *input_V, const Tensor *input_O,
Tensor *output_dQ, Tensor *output_dK, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dQ, Tensor *output_dK,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, Tensor *output_dV, Tensor *output_dBias,
const Tensor *rng_state, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900 #endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine } // namespace transformer_engine
@@ -217,14 +217,14 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
int64_t maskOutputTensor_virtual = true; int64_t maskOutputTensor_virtual = true;
cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT;
auto maskOutputTensor_reorderType = auto maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE; cudnn_frontend::TensorReordering_t::NONE;
if (is_bprop) { if (is_bprop) {
maskOutputTensor_id = dS_ID; maskOutputTensor_id = dS_ID;
maskOutputTensor_virtual = false; maskOutputTensor_virtual = false;
maskOutputTensor_dataType = tensorType; maskOutputTensor_dataType = tensorType;
maskOutputTensor_reorderType = maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16; cudnn_frontend::TensorReordering_t::F16x16;
} }
auto maskOutputTensor = auto maskOutputTensor =
@@ -357,7 +357,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
// divide (e/ sum(e)) // divide (e/ sum(e))
auto reorder_type = auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16; cudnn_frontend::TensorReordering_t::F16x16;
auto afterDivisionTensor = auto afterDivisionTensor =
cudnn_frontend::TensorBuilder() cudnn_frontend::TensorBuilder()
@@ -448,7 +448,7 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
afterBMM1_stride, true, false); // is virtual afterBMM1_stride, true, false); // is virtual
auto reorder_type = auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16; cudnn_frontend::TensorReordering_t::F16x16;
// after dropout tensor // after dropout tensor
auto afterDropoutTensor = auto afterDropoutTensor =
@@ -918,7 +918,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto reorder_type = auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16; cudnn_frontend::TensorReordering_t::F16x16;
// activation from fprop // activation from fprop
auto pTensor = auto pTensor =
@@ -1246,7 +1246,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
using namespace transformer_engine::fused_attn; using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked( void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training, size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
@@ -1312,8 +1312,8 @@ void fused_attn_max_512_fwd_qkvpacked(
} }
} }
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
@@ -1389,8 +1389,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size."); NVTE_ERROR("Unexpected workspace_size.");
} }
} }
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_Q, const Tensor *input_K,
@@ -1460,7 +1460,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
} }
} }
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, NVTE_Mask_Type mask_type, const Tensor *input_QKV,
@@ -1519,8 +1519,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
} }
} }
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
@@ -1580,8 +1580,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size."); NVTE_ERROR("Unexpected workspace_size.");
} }
} }
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_Q, const Tensor *input_K,
@@ -19,7 +19,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_size, bool is_training, float attn_scale, size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
...@@ -28,8 +28,8 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -28,8 +28,8 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
...@@ -38,8 +38,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -38,8 +38,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *kv_cu_seqlens, const Tensor *rng_state, const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_Q, const Tensor *input_K,
...@@ -49,7 +49,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq ...@@ -49,7 +49,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
const Tensor *kv_cu_seqlens, const Tensor *rng_state, const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, NVTE_Mask_Type mask_type, const Tensor *input_QKV,
...@@ -58,8 +58,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -58,8 +58,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
const Tensor *cu_seqlens, Tensor *workspace, const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle); cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
...@@ -68,8 +68,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -68,8 +68,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_Q, const Tensor *input_K,
......
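The recurring change in the declarations above is purely the argument order: num_head now follows batch directly, so every max-512 entry point takes its shape parameters as (batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim). A minimal standalone sketch (hypothetical caller, not library code; call_max512_fwd is a stand-in name) of deriving those values from a QKV-packed [b, s, 3, h, d] shape and passing them in the new order:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in that only echoes the shape arguments in the new
// (batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim) order.
static void call_max512_fwd(std::size_t batch, std::size_t num_head,
                            std::size_t q_max_seqlen, std::size_t kv_max_seqlen,
                            std::size_t head_dim) {
    std::printf("b=%zu h=%zu s_q=%zu s_kv=%zu d=%zu\n",
                batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim);
}

int main() {
    // Example QKV-packed input with logical shape [b, s, 3, h, d].
    const std::size_t qkv_shape[5] = {2, 512, 3, 16, 64};
    call_max512_fwd(/*batch=*/qkv_shape[0], /*num_head=*/qkv_shape[3],
                    /*q_max_seqlen=*/qkv_shape[1], /*kv_max_seqlen=*/qkv_shape[1],
                    /*head_dim=*/qkv_shape[4]);
    return 0;
}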
@@ -366,8 +366,7 @@ static cudnn_frontend::Tensor createDropoutForward(
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
-.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
-CUDNN_TENSOR_REORDERING_F16x16)
+.setReorderType(cudnn_frontend::TensorReordering_t::F16x16)
.build();
// Scale after dropout
auto scaleDropoutTensor = tensor_create(
@@ -448,8 +447,7 @@ static cudnn_frontend::Tensor createDropoutBackward(
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
-.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
-CUDNN_TENSOR_REORDERING_F16x16)
+.setReorderType(cudnn_frontend::TensorReordering_t::F16x16)
.build();
// Scale after dropout (1 / (1 - p))
auto scaleDropoutTensor = tensor_create(
@@ -992,7 +990,7 @@ static cudnn_frontend::Tensor createdSQBMM(
}
// fused attention FWD FP8
-void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
+void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
@@ -1305,7 +1303,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
}
// fused attention BWD FP8
-void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
+void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
@@ -1935,7 +1933,7 @@ void fused_attn_fp8_fwd_qkvpacked(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
-b, max_seqlen, max_seqlen, h, d,
+b, h, max_seqlen, max_seqlen, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
@@ -2025,7 +2023,7 @@ void fused_attn_fp8_bwd_qkvpacked(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
-b, max_seqlen, max_seqlen, h, d,
+b, h, max_seqlen, max_seqlen, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
@@ -2131,7 +2129,7 @@ void fused_attn_fp8_fwd(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
-b, max_seqlen_q, max_seqlen_kv, h, d,
+b, h, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
@@ -2224,7 +2222,7 @@ void fused_attn_fp8_bwd(
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
-b, max_seqlen_q, max_seqlen_kv, h, d,
+b, h, max_seqlen_q, max_seqlen_kv, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
...
@@ -14,8 +14,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
-size_t b, size_t max_seqlen,
-size_t h, size_t d,
+size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
@@ -30,8 +29,7 @@ void fused_attn_fp8_fwd_qkvpacked(
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
-size_t b, size_t max_seqlen,
-size_t h, size_t d,
+size_t b, size_t h, size_t max_seqlen, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
@@ -49,8 +47,7 @@ void fused_attn_fp8_bwd_qkvpacked(
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
-size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-size_t h, size_t d,
+size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
@@ -66,8 +63,7 @@ void fused_attn_fp8_fwd(
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
-size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-size_t h, size_t d,
+size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O,
...
@@ -30,109 +30,6 @@ void generateMatrixStrides(
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
// to be deprecated in the future
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else if ((layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 2* h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 2* h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_S_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
break;
case NVTE_QKV_Matrix::NVTE_O_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
break;
}
// new way of getting strides
switch (layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
@@ -497,4 +394,27 @@ cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine
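The block deleted above ("to be deprecated in the future") computed strides per matrix for the old three-value layout enum; the surviving switch keys directly on the new layout enum (NVTE_SB3HD and friends). For reference, a small standalone sketch (plain C++, not the library's generateMatrixStrides) of the strides that correspond to a Q tensor stored contiguously as [b, s_q, h, d] — the NVTE_NOT_INTERLEAVED branch of the removed Q-matrix case, i.e. BSHD-style storage in the new naming:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
    const std::int64_t b = 2, h = 16, s_q = 128, d = 64;
    // cuDNN consumes Q as a 4-D [batch, head, seqlen, hidden] tensor; for
    // contiguous [b, s_q, h, d] storage, element (bi, hi, si, di) lives at
    // offset bi*s_q*h*d + si*h*d + hi*d + di, giving the strides below.
    const std::array<std::int64_t, 4> dims    = {b, h, s_q, d};
    const std::array<std::int64_t, 4> strides = {s_q * h * d, d, h * d, 1};
    for (int i = 0; i < 4; ++i) {
        std::cout << "dim[" << i << "] = " << dims[i]
                  << ", stride[" << i << "] = " << strides[i] << "\n";
    }
    return 0;
}

The packed layouts only change the per-sequence step (for example 3 * h * d for a QKV-interleaved buffer), which is exactly what the removed branches above spelled out case by case.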
@@ -12,6 +12,7 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cstdint>
#include <mutex>
@@ -95,6 +96,34 @@ struct FADescriptor {
}
};
struct FADescriptor_v1 {
std::int64_t b;
std::int64_t h;
std::int64_t hg;
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
float attnScale;
bool isTraining;
float dropoutProbability;
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
cudnn_frontend::DataType_t tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
rhs.tensor_type);
}
};
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
@@ -107,6 +136,7 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
...
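The FADescriptor_v1 struct introduced above collects every parameter that selects a distinct fused-attention graph (shapes, dropout, layout, bias/mask type, tensor data type), and its operator< exists so the descriptor can key an ordered container. Below is a minimal sketch of that intended use; it assumes the caching pattern mirrors how the pre-existing FADescriptor is used alongside cudnnExecutionPlanManager, and CachedGraph plus the include path are hypothetical stand-ins for the built cudnn_frontend v1 graph and the header shown above.

#include <map>
#include <memory>

#include "fused_attn/utils.h"  // assumed path to the header declaring FADescriptor_v1

// Hypothetical stand-in for a built cudnn_frontend v1 graph / execution plan.
struct CachedGraph {};

// std::map needs a strict weak ordering on its key type; that is what
// FADescriptor_v1::operator< (via std::tie over all fields) provides.
using GraphCache =
    std::map<transformer_engine::fused_attn::FADescriptor_v1,
             std::shared_ptr<CachedGraph>>;

std::shared_ptr<CachedGraph> get_or_build(
        GraphCache &cache,
        const transformer_engine::fused_attn::FADescriptor_v1 &key) {
    auto it = cache.find(key);                     // ordered lookup via operator<
    if (it != cache.end()) return it->second;      // reuse the graph for this config
    auto graph = std::make_shared<CachedGraph>();  // otherwise build it once
    cache.emplace(key, graph);
    return graph;
}

In the real integration the mapped value would be the built graph plus whatever per-tensor metadata its variant pack needs; the sketch only illustrates why the ordering operator is part of the struct.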