Unverified commit 32db3928, authored by cyanguwa, committed by GitHub

Integrate cuDNN frontend v1 to fused attention (#497)



* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/paddle for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/pytorch lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify stride generation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix and/or logic in get_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix flag_max512 and test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove v.contiguous() since get_qkv_layout covers it
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip fp8 tests for sm89
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert mask type to comma-separated list
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last two commits
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* integrate v1/pre-release-5
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up pre-release-5 integration and fix FA2.1 commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout to 0 if not training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test bias/alibi and padding+causal; add alibi to unfused DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set flag_arb to false when non-determinism is not allowed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* follow up on previous commit; remove redundant Python env var setting
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor tweaks for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prepare for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix determinism logic for fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add bias to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix gpt_checkpointing/dpa_accuracy problem
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix some segfault issues
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add failure notes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove use of non-determinism var for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for lint and CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix workspace size in bwd and uncomment bias test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_alibi and remove check_support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update tests status
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove workspace_opt from FADescriptor_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable arbitrary backend + post scale bias in Jax; waiting on PR 525
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up bhsd order
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove support for padding_causal + cross for max512
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change alibi bias to float32 for bias_1_4/5 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further clean up tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd fwd output shape for FlashAttention and add backend info for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix definition of workspace limit when dbias is present
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further tweak DP_WORKSPACE_LIMIT definition
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disallow alibi+no_mask for sdpa flash and update alibi tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update jax/paddle after PR 525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable dbias for non-Hopper archs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix layernorm lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused arg for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove build dir in setup.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change selection logic to prefer fused attn on sm90
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix distributed jax test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix h and s order in header
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to cudnn fe v1 branch
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove manual setting of workopt path due to dbias after v1 update
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add post_scale_bias and alibi to sdpa flash support matrix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix support matrix in header files
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move headers back to .cu and change seed/offset to int64
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update Megatron commit in L1 test and remove all prints in fused attn test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix L1 Megatron test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 arg in L1 Megatron script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* print only when debug flag is on
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove checkpoint loading to avoid loading other tests' results
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
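Editor's note: several commits above touch the ALiBi path ("add alibi to unfused DPA", "fix get_alibi"). As background, not as this PR's exact implementation, the per-head ALiBi slopes for a power-of-two head count are conventionally a geometric sequence starting at 2^(-8/n) (Press et al.); a minimal sketch:

```python
def get_alibi_slopes(num_heads: int) -> list:
    """Standard ALiBi slopes for a power-of-two number of heads:
    a geometric sequence 2^(-8/n), 2^(-16/n), ... (hypothetical helper,
    illustrating the convention rather than Transformer Engine's code)."""
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

# For 8 heads this yields 1/2, 1/4, ..., 1/256.
slopes = get_alibi_slopes(8)
```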
parent ff760a9d
-Subproject commit 12f35fa2be5994c1106367cac2fba21457b064f4
+Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
@@ -8,6 +8,7 @@ set -e
 git clone https://github.com/NVIDIA/Megatron-LM.git
 cd Megatron-LM
-git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9
+git checkout bcce6f54e075e3c3374ea67adefe54f3f2da2b07
+sed -i -e '1504,1505d' megatron/model/transformer.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py
 python $TE_PATH/tests/pytorch/distributed/print_logs.py
@@ -77,10 +77,11 @@ class TestDistributedSelfAttn:
     is_training = True
     scaling_factor = 1.0
-    _, seqlen, _, _, hidden = data_shape
+    _, seqlen, _, num_head, hidden = data_shape
     if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
-                                          attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
+                                          attn_mask_type, dropout_prob, num_head, num_head,
+                                          seqlen, seqlen, hidden):
         pytest.skip(f"No FusedAttn backwend found")
     def target_func(qkv, bias, mask):
@@ -182,10 +183,11 @@ class TestDistributedCrossAttn:
     is_training = True
     scaling_factor = 1.0
-    _, seqlen, _, hidden = data_shape
+    _, seqlen, num_head, hidden = data_shape
     if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
-                                          attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
+                                          attn_mask_type, dropout_prob, num_head, num_head,
+                                          seqlen, seqlen, hidden):
         pytest.skip(f"No FusedAttn backwend found")
     def target_func(q, kv, mask):
@@ -180,12 +180,14 @@ class TestSelfFusedAttn():
     @staticmethod
     def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
-                      head_dim):
+                      num_heads_q, num_heads_kv, head_dim):
         assert isinstance(backend, Backend)
         if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
-                                              attn_mask_type, dropout_probability, s, s, head_dim):
+                                              attn_mask_type, dropout_probability,
+                                              num_heads_q, num_heads_kv,
+                                              s, s, head_dim):
             pytest.skip("Unsupported inputs combination or device compute capability.")
     def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
@@ -197,6 +199,8 @@ class TestSelfFusedAttn():
             backend=backend,
             dropout_probability=dropout_probability,
             dtype=dtype,
+            num_heads_q=h,
+            num_heads_kv=h,
             head_dim=d)
         if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
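Editor's note: one of the commits above is "force dropout to 0 if not training". The pattern it names is simple enough to sketch in isolation (a hypothetical helper, not the PR's actual code): inference passes should never apply dropout, regardless of the configured rate.

```python
def effective_dropout(dropout_probability: float, is_training: bool) -> float:
    # Hypothetical helper mirroring the commit "force dropout to 0 if not training":
    # the configured rate only applies during training; inference uses 0.0.
    return dropout_probability if is_training else 0.0
```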
@@ -48,11 +48,8 @@ class TestGroupSharding(unittest.TestCase):
     def _get_model_and_optimizer(self, model, stage):
         if stage == 1:
             optimizer = DygraphShardingOptimizer(
-                hcg=fleet.get_hybrid_communicate_group(),
-                user_defined_strategy=self.strategy,
-                params=model.parameters(),
-                inner_optimizer_class=paddle.optimizer.AdamW,
-                learning_rate=0.01,
+                paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()),
+                fleet.get_hybrid_communicate_group(),
             )
             model = fleet.distributed_model(model)
             optimizer = fleet.distributed_optimizer(optimizer)
@@ -634,9 +634,11 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-        head_size=head_size,
+        num_heads=num_heads,
+        num_gqa_groups=num_heads,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
+        head_size=head_size,
         dtype=math_dtype,
         dropout=0.0,
         qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
@@ -762,9 +764,11 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-        head_size=hidden_size // num_heads,
+        num_heads=num_heads,
+        num_gqa_groups=num_heads,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
+        head_size=hidden_size // num_heads,
         dtype=math_dtype,
         dropout=0.0,
         qkv_layout="bs3hd",
@@ -940,9 +944,11 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
-        head_size=hidden_size // num_heads,
+        num_heads=num_heads,
+        num_gqa_groups=num_heads,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
+        head_size=hidden_size // num_heads,
         dtype=math_dtype,
         dropout=0.0,
         qkv_layout="bs3hd",
@@ -952,6 +958,8 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
         pytest.skip("cuDNN fused attention is not supported")
     if not is_fused_attention_supported(
         head_size=hidden_size // num_heads,
+        num_heads=num_heads,
+        num_gqa_groups=num_heads,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
         dtype=math_dtype,
@@ -688,9 +688,11 @@ class TestFusedAttn:
             else "bshd_bs2hd"
         )
         fused_attention_backend = get_fused_attention_backend(
-            head_size=self.head_size,
+            num_heads=self.num_heads,
+            num_gqa_groups=self.num_heads,
             q_seqlen=self.q_seqlen,
             kv_seqlen=self.kv_seqlen,
+            head_size=self.head_size,
             dtype=self.dtype,
             dropout=self.dropout_prob,
             qkv_layout=qkv_layout,
@@ -774,9 +776,11 @@ class TestFusedAttn:
         test self attention forward + backward
         """
         if not is_fused_attention_supported(
-            head_size=d,
+            num_heads=h,
+            num_gqa_groups=h,
             q_seqlen=s,
             kv_seqlen=s,
+            head_size=d,
             dtype=dtype,
             dropout=0.0,
             qkv_layout="bs3hd",
@@ -799,9 +803,11 @@ class TestFusedAttn:
         test cross attention forward + backward
         """
         if not is_fused_attention_supported(
-            head_size=d,
+            num_heads=h,
+            num_gqa_groups=h,
             q_seqlen=s_q,
             kv_seqlen=s_kv,
+            head_size=d,
             dtype=dtype,
             dropout=0.0,
             qkv_layout="bshd_bs2hd",
@@ -825,9 +831,11 @@ class TestFusedAttn:
         test flash attention forward + backward
         """
         if not is_fused_attention_supported(
-            head_size=d,
+            num_heads=h,
+            num_gqa_groups=h,
             q_seqlen=s,
             kv_seqlen=s,
+            head_size=d,
             dtype=dtype,
             dropout=0.0,
             qkv_layout="bs3hd",
@@ -102,9 +102,11 @@ def set_random_seed(seed):
 def get_fused_attention_backend(
-    head_size: int,
+    num_heads: int,
+    num_gqa_groups: int,
     q_seqlen: int,
     kv_seqlen: int,
+    head_size: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
     qkv_layout: str = "bs3hd",
@@ -125,6 +127,8 @@ def get_fused_attention_backend(
         AttnBiasType[bias_type],
         AttnMaskType[mask_type],
         dropout,
+        num_heads,
+        num_gqa_groups,
         q_seqlen,
         kv_seqlen,
         head_size,
@@ -132,9 +136,11 @@ def get_fused_attention_backend(
 def is_fused_attention_supported(
-    head_size: int,
+    num_heads: int,
+    num_gqa_groups: int,
     q_seqlen: int,
     kv_seqlen: int,
+    head_size: int,
     dtype: Union[paddle.dtype, str],
     dropout: float,
     qkv_layout: str = "bs3hd",
@@ -143,9 +149,11 @@ def is_fused_attention_supported(
 ) -> bool:
     """Check if cuDNN fused attention is supported for attention config"""
     backend = get_fused_attention_backend(
-        head_size=head_size,
+        num_heads=num_heads,
+        num_gqa_groups=num_gqa_groups,
         q_seqlen=q_seqlen,
         kv_seqlen=kv_seqlen,
+        head_size=head_size,
         dtype=dtype,
         dropout=dropout,
         qkv_layout=qkv_layout,
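Editor's note: the hunks above thread `num_heads` and `num_gqa_groups` through the backend-selection helpers, which is what enables grouped-query attention (GQA) configurations. A hypothetical sanity check (not Transformer Engine code) illustrates the relationship the two parameters must satisfy:

```python
def is_valid_gqa_config(num_heads: int, num_gqa_groups: int) -> bool:
    """Hypothetical check: in grouped-query attention the query heads must
    divide evenly across the KV groups. MHA is groups == heads; MQA is
    groups == 1; anything in between must divide num_heads."""
    return num_gqa_groups >= 1 and num_heads % num_gqa_groups == 0
```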
@@ -81,7 +81,6 @@ options=" \
     --merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt \
     --save-interval ${SAVE_INTERVAL} \
     --save ${CHECKPOINT_DIR} \
-    --load ${CHECKPOINT_DIR} \
     --split ${SPLIT} \
     --clip-grad ${CLIP_GRAD} \
     --weight-decay ${WEIGHT_DECAY} \
@@ -90,8 +89,6 @@ options=" \
     --init-method-std ${INIT_METHOD_STD} \
     --log-params-norm \
     --log-num-zeros-in-grad \
-    --no-query-key-layer-scaling \
-    --DDP-impl local \
     --transformer-impl ${TRANSFORMER_IMPL} \
     --tensorboard-dir ${TENSORBOARD_DIR} \
     --fp8-margin 0 \
@@ -108,7 +105,7 @@ if [[ "$WGRAD_FUSION" == "False" ]]; then
 fi
 if [[ "$FP8" != "False" ]]; then
-    options+=" --fp8-${FP8}"
+    options+=" --fp8-format ${FP8}"
 fi
 if [[ "$DTYPE" != "fp32" ]]; then
@@ -508,6 +508,7 @@ def _test_e2e_checkpointing_get_model(config, dtype):
     sigma = 0.023
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
     return (
         TransformerLayer(
             config.hidden_size,
@@ -524,7 +525,6 @@ def _test_e2e_checkpointing_get_model(config, dtype):
             params_dtype=dtype,
         )
         .cuda()
-        .eval()
     )
@@ -559,9 +559,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
         if p.requires_grad:
             param_grads.append(p.grad.clone())

+    global _cpu_rng_state, _cuda_rng_state
+    _cpu_rng_state = torch.get_rng_state()
+    _cuda_rng_state = torch.cuda.get_rng_state()
+
     del block
     block = _test_e2e_checkpointing_get_model(config, dtype)
     block.load_state_dict(torch.load(path))
+    reset_rng_states()

     for p in block.parameters():
         if p.requires_grad:
@@ -815,21 +820,19 @@ def test_dpa_accuracy(dtype, bs, model):
         DotProductAttention(
             config.num_attention_heads,
             config.embed,
-            attention_dropout=0.1,  # dropout
+            attention_dropout=0.0,  # disable dropout, FU uses rng differently
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )
     torch_dpa = (
         TorchDotProductAttention(
             config.embed,
-            0.1,  # dropout
+            0.0,  # dropout
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )
     te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
@@ -12,55 +12,82 @@
 #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_

 #include "transformer_engine/fused_attn.h"
 #include <cudnn.h>
 #include "common/common.h"

 namespace transformer_engine {
 #if (CUDNN_VERSION >= 8900)

-void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
-                                               size_t head_size, bool is_training, float attn_scale,
-                                               float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                               NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                               const Tensor *input_QKV, const Tensor *input_Bias,
-                                               Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-                                               const Tensor *cu_seqlens, const Tensor *rng_state,
-                                               Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
-                                               size_t head_dim, float attn_scale, float p_dropout,
-                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                               NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                               const Tensor *input_O,
-                                               const Tensor *input_dO, Tensor *output_S,
-                                               Tensor *output_dQKV, Tensor *output_dBias,
-                                               const Tensor *cu_seqlens, const Tensor *rng_state,
-                                               Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
-                                     size_t num_head, size_t head_size, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_K,
-                                     const Tensor *input_V, const Tensor *input_Bias,
-                                     Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-                                     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-                                     const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
-                                     size_t num_head, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_K,
-                                     const Tensor *input_V, const Tensor *input_O,
-                                     const Tensor *input_dO, Tensor *output_S,
-                                     Tensor *output_dQ, Tensor *output_dK,
-                                     Tensor *output_dV, Tensor *output_dBias,
-                                     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-                                     const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
+    size_t batch, size_t num_attn_heads, size_t max_seqlen,
+    size_t head_size, bool is_training, float attn_scale,
+    float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_QKV, const Tensor *input_Bias,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    const Tensor *cu_seqlens, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
+    size_t batch, size_t num_attn_heads, size_t max_seqlen,
+    size_t head_dim, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, const Tensor *input_QKV,
+    const Tensor *input_O, const Tensor *input_dO,
+    const Tensor *input_Bias, Tensor *output_S,
+    Tensor *output_dQKV, Tensor *output_dBias,
+    const Tensor *cu_seqlens, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_fwd_kvpacked(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    bool is_training, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_bwd_kvpacked(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
+    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
+    Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_fwd(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    bool is_training, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_bwd(
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_O,
+    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
+    Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 #endif  // CUDNN_VERSION >= 8900
 }  // namespace transformer_engine
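Editor's note: the header above fixes the h/s parameter order for layouts such as `bshd`, and the commit list includes "simplify stride generation". For a densely packed tensor, strides follow mechanically from the layout's axis order; a generic row-major sketch (a hypothetical helper, not the PR's C++ code):

```python
def row_major_strides(shape):
    """Strides, in elements, of a densely packed row-major tensor.
    E.g. shape (b, s, h, d) for a 'bshd' layout: the last axis has
    stride 1 and each earlier axis multiplies in the sizes after it."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides
```

Swapping the `s` and `h` entries of the shape (as in `bhsd` vs. `bshd`) changes the strides accordingly, which is why a consistent parameter order in the headers matters.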
...@@ -217,14 +217,14 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 ...@@ -217,14 +217,14 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
int64_t maskOutputTensor_virtual = true; int64_t maskOutputTensor_virtual = true;
cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT;
auto maskOutputTensor_reorderType = auto maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE; cudnn_frontend::TensorReordering_t::NONE;
if (is_bprop) { if (is_bprop) {
maskOutputTensor_id = dS_ID; maskOutputTensor_id = dS_ID;
            maskOutputTensor_virtual = false;
            maskOutputTensor_dataType = tensorType;
            maskOutputTensor_reorderType =
                cudnn_frontend::TensorReordering_t::F16x16;
        }
        auto maskOutputTensor =
...@@ -357,7 +357,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
        // divide (e/ sum(e))
        auto reorder_type =
            cudnn_frontend::TensorReordering_t::F16x16;
        auto afterDivisionTensor =
            cudnn_frontend::TensorBuilder()
...@@ -448,7 +448,7 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
                                        afterBMM1_stride, true, false);  // is virtual
        auto reorder_type =
            cudnn_frontend::TensorReordering_t::F16x16;
        // after dropout tensor
        auto afterDropoutTensor =
...@@ -918,7 +918,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
    auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
    auto reorder_type =
        cudnn_frontend::TensorReordering_t::F16x16;
    // activation from fprop
    auto pTensor =
...@@ -1246,7 +1246,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked(
    size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
...@@ -1312,8 +1312,8 @@ void fused_attn_max_512_fwd_qkvpacked(
    }
}
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                     const Tensor *input_Q, const Tensor *input_KV,
...@@ -1389,8 +1389,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
        NVTE_ERROR("Unexpected workspace_size.");
    }
}
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                            size_t kv_max_seqlen, size_t head_dim, bool is_training,
                            float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                            NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                            const Tensor *input_Q, const Tensor *input_K,
...@@ -1460,7 +1460,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
    }
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
                                      size_t head_dim, float attn_scale, float p_dropout,
                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
...@@ -1519,8 +1519,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
    }
}
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                     const Tensor *input_Q, const Tensor *input_KV,
...@@ -1580,8 +1580,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
        NVTE_ERROR("Unexpected workspace_size.");
    }
}
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                            size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                            float p_dropout, NVTE_QKV_Layout qkv_layout,
                            NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                            const Tensor *input_Q, const Tensor *input_K,
...
...@@ -19,7 +19,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
                                      size_t head_size, bool is_training, float attn_scale,
                                      float p_dropout, NVTE_QKV_Layout qkv_layout,
                                      NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
...@@ -28,8 +28,8 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
                                      const Tensor *cu_seqlens, const Tensor *rng_state,
                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                     const Tensor *input_Q, const Tensor *input_KV,
...@@ -38,8 +38,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                            size_t kv_max_seqlen, size_t head_dim, bool is_training,
                            float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                            NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                            const Tensor *input_Q, const Tensor *input_K,
...@@ -49,7 +49,7 @@ void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seq
                            const Tensor *kv_cu_seqlens, const Tensor *rng_state,
                            Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
                                      size_t head_dim, float attn_scale, float p_dropout,
                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
...@@ -58,8 +58,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
                                      const Tensor *cu_seqlens, Tensor *workspace,
                                      cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                     const Tensor *input_Q, const Tensor *input_KV,
...@@ -68,8 +68,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                            size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                            float p_dropout, NVTE_QKV_Layout qkv_layout,
                            NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                            const Tensor *input_Q, const Tensor *input_K,
...
...@@ -14,8 +14,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
                size_t b, size_t h, size_t max_seqlen, size_t d,
                bool is_training, float attn_scale,
                float p_dropout, NVTE_QKV_Layout qkv_layout,
                const Tensor *input_QKV,
...@@ -30,8 +29,7 @@ void fused_attn_fp8_fwd_qkvpacked(
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
                size_t b, size_t h, size_t max_seqlen, size_t d,
                float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                const Tensor *input_QKV,
                const Tensor *input_O,
...@@ -49,8 +47,7 @@ void fused_attn_fp8_bwd_qkvpacked(
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
                size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
                bool is_training, float attn_scale,
                float p_dropout, NVTE_QKV_Layout qkv_layout,
                const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
...@@ -66,8 +63,7 @@ void fused_attn_fp8_fwd(
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
                size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
                float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
                const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
                const Tensor *input_O,
...
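The recurring change in this diff moves `num_head` from the third to the second scalar argument, so the shape parameters follow the batch, head, sequence, head-dim order of the underlying tensors. The sketch below illustrates the migration a caller must make; the types and function names are hypothetical stand-ins for illustration, not the actual Transformer Engine API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical dimension bundle mirroring the (b, h, s, d) ordering.
struct AttnDims { std::size_t batch, num_head, max_seqlen, head_dim; };

// Old parameter order: (batch, max_seqlen, num_head, head_dim).
inline AttnDims dims_old(std::size_t batch, std::size_t max_seqlen,
                         std::size_t num_head, std::size_t head_dim) {
    return {batch, num_head, max_seqlen, head_dim};
}

// New parameter order: (batch, num_head, max_seqlen, head_dim).
// Callers migrating to the new signatures swap the 2nd and 3rd arguments.
inline AttnDims dims_new(std::size_t batch, std::size_t num_head,
                         std::size_t max_seqlen, std::size_t head_dim) {
    return {batch, num_head, max_seqlen, head_dim};
}
```

Because both seqlen and head-count arguments are `size_t`, the compiler cannot catch a call site left in the old order, which is why every declaration and definition is updated together in this commit.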