Unverified Commit 32db3928 authored by cyanguwa, committed by GitHub

Integrate cuDNN frontend v1 to fused attention (#497)



* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/paddle for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax/pytorch lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify stride generation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix and/or logic in get_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix flag_max512 and test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove v.contiguous() since get_qkv_layout covers it
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip fp8 tests for sm89
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert mask type to comma-separated list
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last two commits
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* integrate v1/pre-release-5
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup prerelease5 integration and fix FA2.1 commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout to 0 if not training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* testing bias/alibi and padding+causal; add alibi to unfused DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set flag_arb to false when non determinism is not allowed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* followup on prev commit; remove redundant python env var setting
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor tweaks for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prepare for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix determinism logic for fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add bias to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix gpt_checkpointing/dpa_accuracy problem
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix some seg fault issues
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add failure notes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove use of non-deter var for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for lint and CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix workspace size in bwd and uncomment bias test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_alibi and remove check_support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update tests status
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove workspace_opt from FADescriptor_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable arbitrary backend + post scale bias in Jax; waiting on PR 525
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up bhsd order
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove support for padding_causal + cross for max512
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change alibi bias to float32 for bias_1_4/5 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further clean up tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd fwd output shape for FlashAttention and add backend info for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix definition of workspace limit when dbias is present
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further tweak DP_WORKSPACE_LIMIT definition
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disallow alibi+no_mask for sdpa flash and update alibi tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable dbias for non-hopper archs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix layernorm lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused arg for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove build dir in setup.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change selection logic to prefer fused attn on sm90
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix distributed jax test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix h and s order in header
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to cudnn fe v1 branch
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove manual setting of workopt path due to dbias after v1 update
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add post_scale_bias and alibi to sdpa flash support matrix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix support matrix in header files
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move headers back to .cu and change seed/offset to int64
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update Megatron commit in L1 test and remove all prints in fused attn test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix L1 Megatron test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 arg in L1 Megatron script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* print only when debug flag is on
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove checkpointing loading to avoid loading other tests results
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent ff760a9d
......@@ -18,109 +18,56 @@ extern "C" {
#endif
/*! \enum NVTE_QKV_Layout
* \brief Memory layouts of QKV tensors
* `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, the number of heads,
head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
`SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
or padded to the same length, and `THD`-based layouts are used when sequences have
different lengths in a batch.
* \note {`NVTE_QKV_INTERLEAVED`, `NVTE_KV_INTERLEAVED` and `NVTE_NOT_INTERLEAVED`
will be deprecated in the next release. Please use their equivalent enums instead, i.e. `NVTE_T3HD`,
`NVTE_THD_T2HD` and `NVTE_THD_THD_THD` when sequences are of variable lengths, and `NVTE_BS3HD`,
`NVTE_BSHD_BS2HD` and `NVTE_BSHD_BSHD_BSHD` when sequences are of equal length or padded
to equal length.}
* \brief Memory layouts of QKV tensors.
* `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads,
* head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
* `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
* or padded to the same length, and `THD`-based layouts are used when sequences have
* different lengths in a batch.
*/
enum NVTE_QKV_Layout {
/*! Separate Q, K, V tensors.
\verbatim
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
K: [total_seqs_kv, num_heads, head_dim]
| K K K ... K
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
V: [total_seqs_kv, num_heads, head_dim]
| V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
\endverbatim
*/
NVTE_NOT_INTERLEAVED = 0,
/*! Packed QKV.
\verbatim
QKV: [total_seqs, 3, num_heads, head_dim]
| Q Q Q ... Q K K K ... K V V V ... V
| \___________ _____________/
total_seqs <| \/
| num_heads * head_dim
\endverbatim
*/
NVTE_QKV_INTERLEAVED = 1,
/*! Q and packed KV.
\verbatim
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
KV: [total_seqs_kv, 2, num_heads, head_dim]
| K K K ... K V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
\endverbatim
*/
NVTE_KV_INTERLEAVED = 2,
NVTE_SB3HD = 3,
NVTE_SBH3D = 4,
NVTE_SBHD_SB2HD = 5,
NVTE_SBHD_SBH2D = 6,
NVTE_SBHD_SBHD_SBHD = 7,
NVTE_BS3HD = 8,
NVTE_BSH3D = 9,
NVTE_BSHD_BS2HD = 10,
NVTE_BSHD_BSH2D = 11,
NVTE_BSHD_BSHD_BSHD = 12,
NVTE_T3HD = 13,
NVTE_TH3D = 14,
NVTE_THD_T2HD = 15,
NVTE_THD_TH2D = 16,
NVTE_THD_THD_THD = 17,
NVTE_SB3HD = 0, /*!< SB3HD layout */
NVTE_SBH3D = 1, /*!< SBH3D layout */
NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */
NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */
NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */
NVTE_BS3HD = 5, /*!< BS3HD layout */
NVTE_BSH3D = 6, /*!< BSH3D layout */
NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */
NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */
NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */
NVTE_T3HD = 10, /*!< T3HD layout */
NVTE_TH3D = 11, /*!< TH3D layout */
NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */
NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */
NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */
};
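As an illustration of the naming scheme (shapes only; the sizes below are made-up example values, not library code):

import torch

b, s, h, d, t = 2, 3, 4, 8, 5            # batch, seqlen, heads, head size, total tokens
qkv_bs3hd = torch.empty(b, s, 3, h, d)   # NVTE_BS3HD: Q, K, V packed on dim 2
kv_sbh2d  = torch.empty(s, b, h, 2, d)   # NVTE_SBHD_SBH2D: the packed KV half
q_thd     = torch.empty(t, h, d)         # NVTE_THD_THD_THD: variable-length Q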
/*! \enum NVTE_QKV_Layout_Group
* \brief Grouping of QKV layouts
* \brief QKV layout groups
*/
enum NVTE_QKV_Layout_Group {
/*! 3HD QKV layouts, e.g. BS3HD */
/*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */
NVTE_3HD = 0,
/*! H3D QKV layouts, e.g. BSH3D */
/*! H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D */
NVTE_H3D = 1,
/*! HD_2HD QKV layouts, e.g. BSHD_BS2HD */
/*! HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD */
NVTE_HD_2HD = 2,
/*! HD_H2D QKV layouts, e.g. BSHD_BSH2D */
/*! HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D */
NVTE_HD_H2D = 3,
/*! HD_HD_HD QKV layouts, e.g. BSHD_BSHD_BSHD */
/*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
NVTE_HD_HD_HD = 4,
};
/*! \enum NVTE_QKV_Format
* \brief Dimension formats for QKV tensors
* \brief QKV formats
*/
enum NVTE_QKV_Format {
/*! SBHD QKV format */
/*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */
NVTE_SBHD = 0,
/*! BSHD QKV format */
/*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */
NVTE_BSHD = 1,
/*! THD QKV format */
/*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
NVTE_THD = 2,
};
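The group and format of a layout can be read directly off its name. A minimal standalone sketch of that correspondence (mirroring, but not reproducing, nvte_get_qkv_layout_group and nvte_get_qkv_format):

def qkv_layout_group(layout: str) -> str:
    # Drop the format letters (s/b/t) from each segment: "bshd_bs2hd" -> "hd_2hd".
    return "_".join(seg.lstrip("sbt") for seg in layout.lower().split("_"))

def qkv_format(layout: str) -> str:
    # All segments of a layout share one format: "sbhd", "bshd", or "thd".
    return layout.lower().split("_")[0].replace("3", "")

assert qkv_layout_group("sbh3d") == "h3d"
assert qkv_layout_group("bshd_bs2hd") == "hd_2hd"
assert qkv_format("t3hd") == "thd"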
......@@ -133,7 +80,9 @@ enum NVTE_Bias_Type {
/*! Bias before scale */
NVTE_PRE_SCALE_BIAS = 1,
/*! Bias after scale */
NVTE_POST_SCALE_BIAS = 2
NVTE_POST_SCALE_BIAS = 2,
/*! ALiBi */
NVTE_ALIBI = 3,
};
/*! \enum NVTE_Mask_Type
......@@ -164,7 +113,7 @@ enum NVTE_Fused_Attn_Backend {
NVTE_FP8 = 2,
};
/*! \brief Get layout group for a given QKV layout
/*! \brief Get QKV layout group for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
......@@ -172,7 +121,7 @@ enum NVTE_Fused_Attn_Backend {
*/
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
/*! \brief Get QKV format for a given QKV layout
/*! \brief Get QKV format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
......@@ -188,6 +137,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
......@@ -198,8 +149,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
float dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -211,14 +164,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
......@@ -256,14 +208,13 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
......@@ -310,12 +261,13 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
......@@ -358,12 +310,13 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
......@@ -417,10 +370,12 @@ void nvte_fused_attn_bwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | |
| | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
......@@ -469,10 +424,12 @@ void nvte_fused_attn_fwd(
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | |
| | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
......
......@@ -1637,6 +1637,8 @@ class FusedAttnHelper:
attn_bias_type: NVTE_Bias_Type
attn_mask_type: NVTE_Mask_Type
dropout_probability: float
num_heads_q: int
num_heads_kv: int
max_seqlen_q: int
max_seqlen_kv: int
head_dim: int
......@@ -1652,6 +1654,7 @@ class FusedAttnHelper:
self.qkv_layout, self.attn_bias_type,
self.attn_mask_type,
self.dropout_probability,
self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim)
......@@ -1733,8 +1736,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
output_dtype = qkv_dtype
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, max_seqlen, max_seqlen,
head_dim).get_fused_attn_backend()
attn_mask_type, dropout_probability, num_head, num_head,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
......@@ -2087,8 +2090,9 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
output_dtype = q_dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, q_max_seqlen,
kv_max_seqlen, q_head_dim).get_fused_attn_backend()
attn_bias_type, attn_mask_type, dropout_probability,
q_num_head, kv_num_head,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
......
......@@ -740,11 +740,13 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
mask_type, dropout_probability, q_num_heads, kv_num_heads,
q_max_seqlen, kv_max_seqlen, head_dim);
return backend;
}
......@@ -799,7 +801,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......@@ -975,7 +978,8 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......
......@@ -117,6 +117,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
......
......@@ -444,8 +444,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.dropout_rate, q_seqlen,
kv_seqlen, self.head_dim)
self.dropout_rate,
self.num_heads, self.num_heads,
q_seqlen, kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
......
......@@ -40,13 +40,14 @@ class QKVLayout(Enum):
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim):
"""
To check whether the fused attention kernel is available
"""
return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
attn_mask_type.value, dropout_probability, max_seqlen_q, max_seqlen_kv,
head_dim).is_fused_attn_kernel_available()
attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
......
......@@ -128,10 +128,12 @@ inline DType Int2NvteDType(int64_t dtype) {
inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim);
return fused_attention_backend;
}
......
......@@ -237,8 +237,10 @@ class DotProductAttention(paddle.nn.Layer):
self.fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv,
query_layer.shape[-1])
AttnMaskType[self.attn_mask_type], self.attention_dropout,
query_layer.shape[-2],
key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
max_s_q, max_s_kv, query_layer.shape[-1])
is_backend_avail = (self.fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
......
......@@ -69,6 +69,7 @@ else:
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
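The backend-selection prints added in this commit are guarded by this flag, which is read once at import time. A quick illustration, assuming the variable is set before the import:

import os
os.environ["NVTE_DEBUG"] = "1"      # enable "[DotProductAttention]: ..." prints
import transformer_engine.pytorch   # flag is read when the module is imported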
......@@ -119,13 +120,59 @@ class InferenceParams: # pylint: disable=too-few-public-methods
new_inference_value_memory,
)
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
) -> torch.Tensor:
    """
    Generate ALiBi bias in the shape of [1, num_heads, max_seqlen_q, max_seqlen_kv].
    """
    # ALiBi slopes: m_i = 2^(-8i/n) for the largest power of two n <= num_heads,
    # plus interpolated slopes 2^(-4(2i-1)/n) for any remaining heads.
    n = 2 ** math.floor(math.log2(num_heads))
    m_0 = 2.0 ** (-8.0 / n)
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    if n < num_heads:
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])

    # Relative-position matrix d[i, j] = j - i, built from cumulative sums of
    # the strict upper and lower triangles.
    a = torch.ones(max_seqlen_q, max_seqlen_kv)
    b = torch.triu(a, diagonal=1)
    c = b.cumsum(dim=-1)
    bb = torch.tril(a, diagonal=-1)
    cc = bb.cumsum(dim=0)
    d = c - cc

    # Scale each head's distance matrix by its slope: [1, num_heads, sq, skv].
    bias = d.repeat(1, num_heads, 1, 1)
    for i in range(num_heads):
        bias[0, i, :, :] = m[i] * bias[0, i, :, :]
    bias = bias.to(dtype=torch.float32, device="cuda")
    return bias
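A quick sanity check of the helper above (illustrative sizes; runs on a CUDA device). For 8 heads the slopes are 2^-1 through 2^-8:

bias = get_alibi(num_heads=8, max_seqlen_q=4, max_seqlen_kv=4)
assert bias.shape == (1, 8, 4, 4)
assert bias[0, 0, 0, 1].item() == 0.5    # head 0: slope 2^-1, distance j - i = 1
assert bias[0, 0, 1, 0].item() == -0.5   # negative below the diagonal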
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cat((zero, cu_seqlens))
    return cu_seqlens
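For example, with the convention used here (True marks a valid token), two sequences of lengths 3 and 1 padded to max_seqlen 4:

mask = torch.tensor([[[[1, 1, 1, 0]]],
                     [[[1, 0, 0, 0]]]], dtype=torch.bool, device="cuda")
print(get_cu_seqlens(mask))   # tensor([0, 3, 4], device='cuda:0', dtype=torch.int32)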
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
tensor of shape [batch_size + 1,] containing the cumulative sequence
lengths of every sample in the batch and the indices containing valid
samples.
tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1]
containing the indices for the valid tokens.
"""
mask = mask.squeeze(1).squeeze(1)
bs, seqlen = mask.shape
......@@ -147,6 +194,26 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.
return cu_seqlens, indices
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch, padded with out-of-range values.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Flattened positions of the valid tokens in the padded [bs, max_seqlen] layout.
    indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)]
    indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(
        dtype=torch.int64, device="cuda")
    # Pad with the out-of-range index bs * max_seqlen so the result has a fixed shape.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))
    return indices
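For example (illustrative values), sequence lengths 2 and 1 with max_seqlen 3; the valid-token positions come first and the remainder is padded with the out-of-range index 6:

cu_seqlens = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")
indices = get_indices(max_seqlen=3, cu_seqlens=cu_seqlens)
print(indices.flatten().tolist())   # [0, 1, 3, 6, 6, 6]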
@jit_fuser
def pack_tensor(
indices: torch.Tensor,
......@@ -290,34 +357,6 @@ class UnpackTensor(torch.autograd.Function):
return None, None, pack_tensor(ctx.indices, grad_output)
def _unpack_attn_mask_type(attn_mask_type: str) -> Tuple[str, bool]:
"""
Unpacks the attention mask type string and returns a single mask type
and a boolean for whether to apply causal mask. Also ensures that the
combination of masks passed in is supported by one of the attention
backends available.
"""
mask_types = attn_mask_type.split(',')
assert (
all(mask_type in AttnMaskTypes for mask_type in mask_types)
), f"Mask type {attn_mask_type} is not supported."
# Whether or not to apply causal mask toggle.
causal_mask = False
if "causal" in mask_types:
mask_types.remove("causal")
causal_mask = True
if len(mask_types) == 0: # Only apply causal mask.
return "causal", True
if len(mask_types) == 1 and causal_mask: # Causal + padding masks
assert mask_types[0] == "padding", f"Causal + {mask_types[0]} masking not supported."
return "padding", True
if len(mask_types) == 1: # Arbitrary or padding or no_mask
return mask_types[0], False
raise RuntimeError("Unsupported combination of mask types.")
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
recv_tensor, recv_src,
cp_group, batch_p2p_comm):
......@@ -878,7 +917,6 @@ class _SplitAlongDim(torch.autograd.Function):
return torch.cat(grad_outputs, dim = split_dim), None, None
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
......@@ -921,7 +959,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""core attention fprop"""
"""Unfused attention fprop"""
assert (qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
......@@ -932,9 +970,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [x.transpose(0, 1)
for x in [query_layer, key_layer, value_layer]]
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
......@@ -1003,10 +1038,13 @@ class UnfusedDotProductAttention(torch.nn.Module):
+ core_attention_bias).view(-1, output_size[2], output_size[3])
matmul_result /= scale
elif core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
if core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
if core_attention_bias_type == "alibi":
core_attention_bias = get_alibi(output_size[1], output_size[2], output_size[3])
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
......@@ -1016,7 +1054,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
)
matmul_result = (matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3]).to(
dtype=query_layer.dtype)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
......@@ -1173,7 +1212,6 @@ def _get_qkv_layout(
check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
for i, x in enumerate([k, v]))
qkv_layout = None
if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
and check_last_two_dims_offsets_qkv
and not check_last_dim_offsets_qkv):
......@@ -1297,7 +1335,7 @@ class FlashAttention(torch.nn.Module):
for x in [query_layer, key_layer, value_layer]
]
if attn_mask_type == 'padding':
if 'padding' in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism."
if self.attention_type == "self":
......@@ -1305,15 +1343,31 @@ class FlashAttention(torch.nn.Module):
max_seqlen_q == max_seqlen_kv
), "Maximum sequence length for Q and KV should be the same."
if self.layer_number == 1:
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
if cu_seqlens_q is None:
assert (attention_mask is not None
), "Please provide attention_mask for padding!"
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
else:
_cu_seqlens_q = cu_seqlens_q
_indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
_cu_seqlens_kv = _cu_seqlens_q
query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
_indices_q, query_layer, key_layer, value_layer
)
else:
if self.layer_number == 1:
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask[0])
_cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
if cu_seqlens_q is None or cu_seqlens_kv is None:
assert (attention_mask is not None
), "Please provide attention_mask for padding!"
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
attention_mask[0])
_cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
attention_mask[1])
else:
_cu_seqlens_q = cu_seqlens_q
_cu_seqlens_kv = cu_seqlens_kv
_indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
_indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
query_layer_packed = PackTensors.apply(_indices_q, query_layer)
key_layer_packed, value_layer_packed = PackTensors.apply(
_indices_kv, key_layer, value_layer
......@@ -1355,7 +1409,7 @@ class FlashAttention(torch.nn.Module):
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
causal=attn_mask_type=="causal",
causal="causal" in attn_mask_type,
deterministic=self.deterministic
)
else:
......@@ -1367,11 +1421,11 @@ class FlashAttention(torch.nn.Module):
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs
)
if attn_mask_type == 'padding':
if 'padding' in attn_mask_type:
output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)
if qkv_format == 'sbhd':
......@@ -1380,6 +1434,9 @@ class FlashAttention(torch.nn.Module):
elif qkv_format == 'bshd':
# (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q, -1).contiguous()
elif qkv_format == 'thd':
# thd -> t(hd)
output = output.view(output.shape[0], -1).contiguous()
return output
......@@ -1416,6 +1473,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
qkv, out, cu_seqlens = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
dqkv = torch.empty_like(qkv)
......@@ -1426,7 +1485,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
ctx.dropout_p, ctx.attn_scale, False,
ctx.attn_mask_type == "causal", None, rng_state
"causal" in ctx.attn_mask_type, None, rng_state
)
dqkv = dqkv[..., :d_out.shape[-1]]
else:
......@@ -1438,8 +1497,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, dqkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
......@@ -1482,6 +1541,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
dq = torch.empty_like(q)
......@@ -1493,7 +1554,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.attn_scale, False,
ctx.attn_mask_type == "causal", None, rng_state
"causal" in ctx.attn_mask_type, None, rng_state
)
dq = dq[..., :d_out.shape[-1]]
dkv = dkv[..., :d_out.shape[-1]]
......@@ -1507,8 +1568,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, dq, dkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
......@@ -1551,6 +1612,8 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
dq = torch.empty_like(q)
......@@ -1563,7 +1626,7 @@ class FusedAttnFunc(torch.autograd.Function):
d_out, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.attn_scale, False,
ctx.attn_mask_type == "causal", None, rng_state
"causal" in ctx.attn_mask_type, None, rng_state
)
dq = dq[..., :d_out.shape[-1]]
dk = dk[..., :d_out.shape[-1]]
......@@ -1578,8 +1641,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, dq, dk, dv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
......@@ -1602,19 +1665,17 @@ class FusedAttention(torch.nn.Module):
| flash based | no | yes |
| cuDNN based | yes | yes |
| qkv dtype | fp16/bf16 | fp16/bf16 |
| attn_type | self/cross | self |
| attn_type | self/cross | self/cross |
| qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | |
| - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd, sbh3d, bsh3d |
| | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd |
| | bshd_bshd_bshd | sbhd_sbh2d, bshd_bsh2d |
| | | sbhd_sbhd_sbhd, bshd_bshd_bshd |
| mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias |
| mask_type | causal/padding/no_mask | causal/padding/no_mask |
| bias_type | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias |
| dropout | yes | yes |
| max_seqlen | <=512 | any |
| head_dim | 64 | 64,128 |
| max_seqlen | <=512, multiple of 64 | any, multiple of 64 |
| head_dim | 64 | <=128, multiple of 8 |
| output dtype | fp16/bf16 | fp16/bf16 |
"""
......@@ -1624,6 +1685,8 @@ class FusedAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
) -> None:
super().__init__()
......@@ -1634,6 +1697,22 @@ class FusedAttention(torch.nn.Module):
self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
and _flash_attn_2_available
and get_device_compute_capability() == (9, 0))
self.layer_number = 1 if layer_number is None else layer_number
if deterministic:
# workspace optimization path is deterministic
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
# or when bias gradient needs to be computed
# - n: enables workspace optimization when required workspace is <= n bytes
# - -1: enables workspace optimization always
# - 0: disables workspace optimization always
if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
def forward(
self,
......@@ -1644,6 +1723,7 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
fused_attention_backend:
tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
......@@ -1668,6 +1748,10 @@ class FusedAttention(torch.nn.Module):
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (
qkv_format != 'thd'
), 'FusedAttention does not support qkv_format = thd!'
if qkv_format in ['sbhd', 'bshd']:
if qkv_format == 'sbhd':
batch_size, max_seqlen_q, max_seqlen_kv = (
......@@ -1675,31 +1759,48 @@ class FusedAttention(torch.nn.Module):
if qkv_format == 'bshd':
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
if qkv_format == 'thd':
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item()
if 'padding' in attn_mask_type:
global _cu_seqlens_q, _cu_seqlens_kv
if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
# use cu_seqlens when both cu_seqlens and attention_mask are present
if self.layer_number == 1:
_cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
elif attention_mask is not None:
if self.attention_type == "self":
if self.layer_number == 1:
_cu_seqlens_q = get_cu_seqlens(attention_mask)
_cu_seqlens_kv = _cu_seqlens_q
else:
if self.layer_number == 1:
_cu_seqlens_q = get_cu_seqlens(attention_mask[0])
_cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
else:
raise Exception("Please provide attention_mask or cu_seqlens for padding!")
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
else:
if self.layer_number == 1:
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
_cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
else:
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype]
use_FAv2_bwd = (self.use_FAv2_bwd
and (core_attention_bias_type == "no_bias")
and (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
with self.attention_dropout_ctx():
......@@ -1733,7 +1834,7 @@ class DotProductAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` in the `forward` call is only used when
:attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
:attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
.. warning::
......@@ -1759,18 +1860,20 @@ class DotProductAttention(torch.nn.Module):
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
type of attention mask passed into softmax operation, options are "`causal`",
"`padding`", "`arbitrary`", "`no_mask`". For the "`causal`" mask,
TransformerEngine calculates and applies an upper triangular mask to
the softmax input. An "`arbitrary`" mask is an arbitrary user defined mask
broadcastable to the shape of softmax input. The "`padding`" mask is used
for providing locations of padded tokens in the batch, which should be of
the shape [batch_size, 1, 1, seq_len]. No mask is applied for the "`no_mask`"
option. For the `"arbitrary"` and `"padding"` mask types, the argument
:attr:`attention_mask` must be passed into `forward` call. The "`causal`"
mask can also be applied in conjunction with "`padding`" mask by passing
in multiple mask type as a comma separated string, for example,
`attn_mask_type="causal,padding"`.
type of attention mask passed into softmax operation, options are "`no_mask`",
"`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
"`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
forward arg is useful for dynamically changing mask types, e.g. a different mask
for training and inference. For "`no_mask`", no attention mask is applied. For
"`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
and applies an upper triangular mask to the softmax input. No user input is
needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
:attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
need to provide a mask that is broadcastable to the shape of softmax input.
attention_type: str, default = `self`
type of attention, either "`self`" and "`cross`".
layer_number: int, default = `None`
......@@ -1786,12 +1889,6 @@ class DotProductAttention(torch.nn.Module):
have different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to gain the layout information.
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
Parallelism parameters
----------------------
......@@ -1833,6 +1930,9 @@ class DotProductAttention(torch.nn.Module):
super().__init__()
self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",","_")
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
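A standalone sketch of the mask-type normalization above (hypothetical helper name, not part of the module):

def normalize_mask_type(attn_mask_type: str) -> str:
    attn_mask_type = attn_mask_type.replace(",", "_")
    if attn_mask_type == "causal_padding":
        attn_mask_type = "padding_causal"
    return attn_mask_type

assert normalize_mask_type("padding,causal") == "padding_causal"
assert normalize_mask_type("causal,padding") == "padding_causal"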
......@@ -1900,9 +2000,11 @@ class DotProductAttention(torch.nn.Module):
# Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs.
if self.use_fused_attention:
self.fused_attention = FusedAttention(
norm_factor, **attn_kwargs,
attention_type=attention_type)
self.fused_attention = FusedAttention(norm_factor,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs)
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
......@@ -2016,9 +2118,13 @@ class DotProductAttention(torch.nn.Module):
Key tensor.
value_layer : torch.Tensor
Value tensor.
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
Can be a tuple of 2 masks for cross attention with padding masks.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
......@@ -2027,17 +2133,19 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32.
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
type of attention mask passed into softmax operation.
attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
`arbitrary`}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
It should be 'None' for 'no_bias' and 'alibi' bias types.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
......@@ -2051,9 +2159,16 @@ class DotProductAttention(torch.nn.Module):
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
attn_mask_type = attn_mask_type.replace(",","_")
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format is None:
qkv_format = self.qkv_format
attn_mask_type, causal_mask = _unpack_attn_mask_type(attn_mask_type)
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
......@@ -2136,7 +2251,7 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False
if (_flash_attn_2_1_plus
and causal_mask
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv):
warnings.warn(
"Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
......@@ -2156,19 +2271,15 @@ class DotProductAttention(torch.nn.Module):
# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
# no_mask | All
# padding | UnfusedDotProductAttention, FlashAttention, FusedAttention
# causal | All
# padding | UnfusedDotProductAttention, FlashAttention
# padding + causal | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
# no_mask | All
# causal + padding | FlashAttention
#
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
elif attn_mask_type == "padding" and causal_mask:
assert use_flash_attention, "No attention backend available for causal + padding masks."
elif attn_mask_type == "padding":
use_fused_attention = False
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
......@@ -2178,23 +2289,28 @@ class DotProductAttention(torch.nn.Module):
AttnBiasType[core_attention_bias_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
max_seqlen_q, max_seqlen_kv,
query_layer.shape[-1])
query_layer.shape[-2], # num_attn_heads
key_layer.shape[-2], # num_gqa_groups
max_seqlen_q,
max_seqlen_kv,
query_layer.shape[-1], # head_dim
)
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = (use_fused_attention
and is_backend_avail
and self.num_gqa_groups == self.num_attention_heads)
if (self.deterministic
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
use_fused_attention = False
warnings.warn(
"Disabling usage of FusedAttention since the FusedAttention"
"backend does not support deterministic exection."
)
and is_backend_avail)
# Select FusedAttention on sm90 and FlashAttention on others for performance
if (use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
if self.device_compute_capability == (9, 0):
use_flash_attention = False
if use_flash_attention:
if _NVTE_DEBUG:
print("[DotProductAttention]: using flash-attn",_flash_attn_version)
return self.flash_attention(query_layer,
key_layer,
value_layer,
......@@ -2212,6 +2328,9 @@ class DotProductAttention(torch.nn.Module):
), "Context parallelism is only implemented with Flash Attention!"
if use_fused_attention:
if _NVTE_DEBUG:
print("[DotProductAttention]: using cuDNN fused attention (backend "
+ str(int(fused_attention_backend)) + ")")
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention,
query_layer,
......@@ -2221,6 +2340,7 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
......@@ -2230,11 +2350,14 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA")
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -2267,8 +2390,8 @@ class MultiheadAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
Argument :attr:`attention_mask` in the `forward` call is only used when
:attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
Parameters
----------
......@@ -2295,7 +2418,8 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
......@@ -2646,16 +2770,22 @@ class MultiheadAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes `"padding"` or `"arbitrary"`.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
attn_mask_type: {'causal', 'padding', 'no_mask', arbitrary}, default = `None`
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensor(s) used to mask out attention softmax input.
It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
default = `None`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
......@@ -2682,9 +2812,10 @@ class MultiheadAttention(torch.nn.Module):
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`}
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
Bias tensor for Q * K.T, shape [1, num_heads, max_seqlen_q, max_seqlen_kv].
It should be `None` for the `no_bias` and `alibi` bias types.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
"""
......@@ -2693,10 +2824,11 @@ class MultiheadAttention(torch.nn.Module):
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
if attn_mask_type == "padding" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
if "padding" in attn_mask_type and attention_mask is not None:
for i, _ in enumerate(attention_mask):
assert (
attention_mask[i].dtype == torch.bool
), "Attention mask must be in boolean type!"
assert (core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
......
......@@ -22,11 +22,11 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")
AttnMaskTypes = ("causal", "padding", "padding_causal", "arbitrary", "no_mask")
AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi")
QKVLayouts = (
"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
......
......@@ -33,9 +33,6 @@ TORCH_DType = {
}
QKVLayout = {
"not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
"qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
"kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
"sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
"sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
"sbhd_sb2hd": NVTE_QKV_Layout.NVTE_SBHD_SB2HD,
......@@ -57,12 +54,14 @@ AttnBiasType = {
"no_bias": NVTE_Bias_Type.NVTE_NO_BIAS,
"pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
"post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
"alibi": NVTE_Bias_Type.NVTE_ALIBI,
}
AttnMaskType = {
"no_mask": NVTE_Mask_Type.NVTE_NO_MASK,
"padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
"padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
}
FusedAttnBackend = {
......@@ -76,84 +75,6 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
def check_tensor(x: torch.Tensor):
"""Check tensor properties."""
assert (x.is_cuda and x.is_contiguous()
), "Tensor should be a GPU tensor and contiguous."
def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(qkv)
assert (qkv.dtype is dtype
and qkv.dim() == 4
and qkv.shape[1] == 3
), """QKV should be in [total_seqs, 3, num_heads, head_dim] shape
and {dtype} dtype."""
def check_q(q: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(q)
assert (q.dtype is dtype
and q.dim() == 3
), """Q should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_kv(kv: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(kv)
assert (kv.dtype is dtype
and kv.dim() == 4
and kv.shape[1] == 2
), """KV should be in [total_seqs, 2, num_heads, head_dim] shape
and {dtype} dtype."""
def check_o(o: torch.Tensor, dtype: torch.dtype):
"""Check tensor properties."""
check_tensor(o)
assert (o.dtype is dtype
and o.dim() == 3
), """O and dO should be in [total_seqs, num_heads, head_dim] shape
and {dtype} dtype."""
def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
"""Check tensor properties."""
check_tensor(stats)
assert (stats.dtype is torch.float32
and stats.dim() == 4
and stats.shape == torch.Size([b, h, s, 1])
), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1]
shape and float32 dtype."""
def check_cu_seqlens(cu_seqlens: torch.Tensor):
"""Check tensor properties."""
check_tensor(cu_seqlens)
assert (cu_seqlens.dtype is torch.int32
and cu_seqlens.dim() == 1
), """cu_seqlens should be in [batch_size +1] shape and int32 dtype."""
def check_scalar(scalar: torch.Tensor):
"""Check tensor properties."""
check_tensor(scalar)
assert (scalar.dtype is torch.float32
and scalar.dim() <= 1
and scalar.numel() == 1
), "amax/scale/descale tensors should be scalars in float32 dtype."
def check_rng_state(rng_state: torch.Tensor):
"""Check tensor properties."""
check_tensor(rng_state)
assert (rng_state.dtype is torch.int64
and rng_state.numel() == 2
), "rng_state should be [seed, offset] and in int64 dtype."
def fused_attn_fwd_qkvpacked(
is_training: bool,
max_seqlen: int,
......@@ -170,7 +91,7 @@ def fused_attn_fwd_qkvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
......@@ -188,8 +109,7 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
......@@ -216,12 +136,12 @@ def fused_attn_fwd_qkvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
qkv_layout: str, default = "sbh3d"
layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -230,7 +150,7 @@ def fused_attn_fwd_qkvpacked(
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
......@@ -257,21 +177,14 @@ def fused_attn_fwd_qkvpacked(
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
d = qkv.size(-1)
attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type != "no_bias":
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = qkv.size(2) if 'h3d' in qkv_layout else qkv.size(3)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype
......@@ -304,16 +217,10 @@ def fused_attn_fwd_qkvpacked(
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(q_scale_s)
check_scalar(q_scale_o)
check_scalar(amax_s)
check_scalar(amax_o)
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
is_training, attn_scale, dropout, fast_zero_fill,
max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
......@@ -345,7 +252,7 @@ def fused_attn_bwd_qkvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "qkv_interleaved",
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -359,14 +266,13 @@ def fused_attn_bwd_qkvpacked(
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details)
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
......@@ -401,12 +307,12 @@ def fused_attn_bwd_qkvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "qkv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
qkv_layout: str, default = "sbh3d"
layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
Returns
----------
......@@ -417,18 +323,8 @@ def fused_attn_bwd_qkvpacked(
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens)
b = cu_seqlens.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_qkv(qkv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
total_seqs = qkv.size(0)
h = qkv.size(2)
d = qkv.size(3)
if attn_scale is None:
d = qkv.size(-1)
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
......@@ -437,8 +333,6 @@ def fused_attn_bwd_qkvpacked(
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
......@@ -452,34 +346,17 @@ def fused_attn_bwd_qkvpacked(
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen)
check_stats(z_inv, b, h, max_seqlen)
# execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked(
b, max_seqlen, total_seqs, h, d,
attn_scale, dropout, fast_zero_fill,
max_seqlen, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
if attn_bias_type == "no_bias":
# return d_qkv when attn_bias_type is no_bias
return output_tensors
# otherwise return (d_qkv, d_bias)
return output_tensors[0], output_tensors[1]
return output_tensors
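With the simplified return above, callers unpack the gradients themselves; a minimal sketch, assuming the surrounding variables:

rets = fused_attn_bwd_qkvpacked(...)  # arguments elided
d_qkv = rets[0]
d_bias = rets[1] if attn_bias_type in ("pre_scale_bias", "post_scale_bias") else None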
def fused_attn_fwd_kvpacked(
......@@ -501,7 +378,7 @@ def fused_attn_fwd_kvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "kv_interleaved",
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
......@@ -524,12 +401,9 @@ def fused_attn_fwd_kvpacked(
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details)
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
packed input tensor KV; shape 2hd or h2d (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of Q and KV; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
......@@ -556,12 +430,13 @@ def fused_attn_fwd_kvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "kv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
qkv_layout: str, default = "sbhd_sbh2d"
layout of QKV;
{"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -570,7 +445,7 @@ def fused_attn_fwd_kvpacked(
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as QKV;
shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
......@@ -597,29 +472,14 @@ def fused_attn_fwd_kvpacked(
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
total_seqs_kv = kv.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type != "no_bias":
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
......@@ -644,8 +504,7 @@ def fused_attn_fwd_kvpacked(
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
is_training, attn_scale, dropout, fast_zero_fill,
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
......@@ -680,7 +539,7 @@ def fused_attn_bwd_kvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "kv_interleaved",
qkv_layout: str = "sbhd_sbh2d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -699,18 +558,15 @@ def fused_attn_bwd_kvpacked(
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details)
kv: torch.Tensor
packed input tensor KV;
shape [total_seqs_kv, 2, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1]
packed input tensor KV; shape h2d or 2hd (see `qkv_layout` for details)
o: torch.Tensor
input tensor O (output of forward);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
d_o: torch.Tensor
input tensor dO (gradient of O);
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
......@@ -746,12 +602,13 @@ def fused_attn_bwd_kvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "kv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
qkv_layout: str, default = "sbhd_sbh2d"
layout of QKV;
{"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
Returns
----------
......@@ -764,26 +621,8 @@ def fused_attn_bwd_kvpacked(
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
qkv_type = TORCH_DType[qkv_dtype]
check_q(q, qkv_type)
check_kv(kv, qkv_type)
check_o(o, qkv_type)
check_o(d_o, qkv_type)
assert (q.size(1) == kv.size(2)
and q.size(2) == kv.size(3)
), "Q and KV must have the same num_heads and head_dim."
total_seqs_q = q.size(0)
total_seqs_kv = q.size(0)
h = q.size(1)
d = q.size(2)
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
......@@ -792,8 +631,6 @@ def fused_attn_bwd_kvpacked(
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
......@@ -807,34 +644,18 @@ def fused_attn_bwd_kvpacked(
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen_q)
check_stats(z_inv, b, h, max_seqlen_q)
# execute kernel
output_tensors = tex.fused_attn_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
attn_scale, dropout, fast_zero_fill,
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
if attn_bias_type == "no_bias":
# return (d_q, d_kv) when attn_bias_type is no_bias
return output_tensors
# otherwise return (d_q, d_kv), d_bias
return output_tensors[:2], output_tensors[2]
return output_tensors
def fused_attn_fwd(
is_training: bool,
......@@ -881,23 +702,11 @@ def fused_attn_fwd(
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim],
where total_seqs_q = cu_seqlens_q[-1],
or [batch_size, seqlen_q, num_heads, head_dim],
or [seqlen_q, batch_size, num_heads, head_dim]
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor
input tensor K;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor
input tensor V;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of Q, K and V; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
......@@ -930,9 +739,9 @@ def fused_attn_fwd(
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -968,19 +777,14 @@ def fused_attn_fwd(
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type != "no_bias":
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
......@@ -1061,23 +865,11 @@ def fused_attn_bwd(
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim],
where total_seqs_q = cu_seqlens_q[-1],
or [batch_size, seqlen_q, num_heads, head_dim],
or [seqlen_q, batch_size, num_heads, head_dim]
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor
input tensor K;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor
input tensor V;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
o: torch.Tensor
input tensor O (output of forward); same data type as Q, K and V;
same shape as Q
......@@ -1125,9 +917,9 @@ def fused_attn_bwd(
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
Returns
----------
......@@ -1142,15 +934,8 @@ def fused_attn_bwd(
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
......@@ -1159,8 +944,6 @@ def fused_attn_bwd(
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
......@@ -1174,18 +957,6 @@ def fused_attn_bwd(
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen_q)
check_stats(z_inv, b, h, max_seqlen_q)
# execute kernel
output_tensors = tex.fused_attn_bwd(
......@@ -1196,4 +967,4 @@ def fused_attn_bwd(
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
return tuple(output_tensors)
return output_tensors
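Across all of these entry points the default softmax scale is derived from the head dimension; a one-line check of the arithmetic, assuming d = 64:

import math
attn_scale = 1.0 / math.sqrt(64)  # = 0.125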
......@@ -13,12 +13,13 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d, bool is_training,
size_t max_seqlen, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
......@@ -36,8 +37,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d, float attn_scale,
size_t max_seqlen, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
......@@ -59,9 +59,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d, bool is_training,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
......@@ -81,10 +79,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d, float attn_scale,
float p_dropout, bool set_zero,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
......
......@@ -16,13 +16,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type,
p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
qkv_layout, bias_type, attn_mask_type, p_dropout,
num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim);
return fused_attention_backend;
}
......@@ -87,9 +90,8 @@ at::PhiloxCudaState init_philox_state(
// fused attention FWD with packed QKV
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
size_t max_seqlen, bool is_training, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
......@@ -104,16 +106,27 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
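// Q's shape is the packed QKV shape with the matrix dimension (3) dropped,
// e.g. sbh3d [s, b, h, 3, d] -> [s, b, h, d].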
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
}
}
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
auto O = torch::empty(o_shape, options);
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero && (h * d % block_size == 0)) {
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
......@@ -123,32 +136,34 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string(" are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
auto bias_sizes = Bias.value().sizes().vec();
std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// extract random number generator seed and offset
......@@ -196,8 +211,18 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
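// Aux pack ordering (as produced by the fused kernels): softmax stats
// first, then rng_state, then the bias tensor when a dBias participates.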
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
}
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
......@@ -229,9 +254,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// fused attention BWD with packed QKV
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
......@@ -250,12 +273,22 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
}
}
auto h = q_shape[q_shape.size() - 2];
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != NVTE_NO_BIAS) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
......@@ -266,10 +299,8 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto max_tokens = dQKV.size(0);
auto self_2d = dQKV.view({max_tokens, -1});
auto fcd_size = self_2d.size(1);
if (set_zero && (fcd_size % block_size == 0)) {
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQKV.fill_(0);
......@@ -283,35 +314,33 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string(" are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32,
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
qkv_type,
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
......@@ -330,8 +359,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
}
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens;
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// create workspace
......@@ -385,9 +415,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// fused attention FWD with packed KV
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
......@@ -405,16 +433,23 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
auto kv_sizes = KV.sizes().vec();
std::vector<size_t> kv_shape{kv_sizes.begin(), kv_sizes.end()};
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
// create output tensor O
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
auto O = torch::empty(o_shape, options);
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero && (h * d % block_size == 0)) {
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
......@@ -424,38 +459,42 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string(" are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
auto bias_shape = Bias.value().sizes().vec();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
auto bias_sizes = Bias.value().sizes().vec();
std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// extract rng seed and offset
......@@ -505,8 +544,18 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
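// Same aux-pack ordering as in the qkvpacked forward above.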
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
}
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
......@@ -540,9 +589,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// fused attention BWD with packed KV
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
......@@ -564,14 +611,28 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
auto kv_sizes = KV.sizes().vec();
std::vector<size_t> kv_shape{kv_sizes.begin(), kv_sizes.end()};
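// K's shape is the packed KV shape with the "2" packing dimension dropped,
// e.g. sbh2d [s, b, h, 2, d] -> [s, b, h, d].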
std::vector<size_t> k_shape;
for (auto i : kv_shape) {
if (i != 2) {
k_shape.push_back(i);
}
}
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
// create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != NVTE_NO_BIAS) {
dBias = torch::empty({1, static_cast<int64_t>(h),
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
......@@ -581,13 +642,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto max_tokens_q = dQ.size(0);
auto self_2d_q = dQ.view({max_tokens_q, -1});
auto fcd_size_q = self_2d_q.size(1);
auto max_tokens_kv = dQ.size(0);
auto self_2d_kv = dQ.view({max_tokens_kv, -1});
auto fcd_size_kv = self_2d_kv.size(1);
if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
......@@ -603,13 +658,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string(" are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
......@@ -617,37 +672,41 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type,
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type,
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d},
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d},
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
// create cu_seqlens tensorwrappers
auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// convert auxiliary tensors from forward to NVTETensors
......@@ -753,8 +812,8 @@ std::vector<at::Tensor> fused_attn_fwd(
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = Q.size(-2);
auto d = Q.size(-1);
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
......@@ -792,7 +851,7 @@ std::vector<at::Tensor> fused_attn_fwd(
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
auto bias_sizes = Bias.value().sizes().vec();
std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
@@ -856,8 +915,18 @@ std::vector<at::Tensor> fused_attn_fwd(
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
-      output_tensor = (i < nvte_aux_tensor_pack.size-1)
-          ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+        if (i < nvte_aux_tensor_pack.size - 2) {
+          output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
+        } else if (i == nvte_aux_tensor_pack.size - 2) {
+          output_tensor = rng_state;
+        } else if (i == nvte_aux_tensor_pack.size - 1) {
+          output_tensor = Bias.value();
+        }
+      } else {
+        output_tensor = (i < nvte_aux_tensor_pack.size-1)
+            ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      }
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
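A Python paraphrase (an assumption, mirroring the C++ above) of the auxiliary tensor layout: with an explicit bias the pack ends in [..., rng_state, bias]; otherwise it ends in [..., rng_state]:

    def pick_aux_output(i, n, has_explicit_bias, alloc, rng_state, bias):
        if n < 2:
            return alloc(i)
        if has_explicit_bias:
            if i == n - 1:
                return bias
            if i == n - 2:
                return rng_state
            return alloc(i)
        return alloc(i) if i < n - 1 else rng_state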
@@ -988,7 +1057,7 @@ std::vector<at::Tensor> fused_attn_bwd(
at::Tensor dBias;
TensorWrapper te_dBias;
-  if (bias_type != NVTE_NO_BIAS) {
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(Q.size(-2)),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
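The alibi exclusions here and in the forward pass reflect that an `alibi` bias is generated from per-head slopes rather than passed in as a tensor, so no te_Bias wrapper is built and no dBias is allocated. As an illustration rather than the exact kernel math, a sketch of one common slope formulation (for power-of-two head counts):

    import torch

    def alibi_slopes(n_heads: int) -> torch.Tensor:
        # slope for head h (1-indexed) is 2 ** (-8 * h / n_heads)
        return torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])

    # the bias added to Q @ K.T is slope * (key_pos - query_pos) per head
    print(alibi_slopes(8))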
@@ -999,9 +1068,9 @@ std::vector<at::Tensor> fused_attn_bwd(
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
-    auto h_q = Q.size(-2);
-    auto h_kv = K.size(-2);
-    auto d = Q.size(-1);
+    auto h_q = q_shape[q_shape.size() - 2];
+    auto h_kv = k_shape[k_shape.size() - 2];
+    auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h_q * d) % block_size == 0)
&& ((h_kv * d) % block_size == 0)
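Note that the backward path reads separate head counts for Q and K because they may differ under multi-query or grouped-query attention; both row sizes must satisfy the block-size condition for the fast fill. A small sketch with hypothetical values:

    q_shape = [1024, 16, 64]    # 16 query heads
    k_shape = [1024, 4, 64]     # hypothetical: 4 KV heads (GQA)
    h_q, h_kv, d = q_shape[-2], k_shape[-2], q_shape[-1]
    block_size = 512            # hypothetical
    use_fast_fill = (h_q * d) % block_size == 0 and (h_kv * d) % block_size == 0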
@@ -1078,7 +1147,7 @@ std::vector<at::Tensor> fused_attn_bwd(
std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
-  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv, te_qkvso_strides;
+  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
......
@@ -149,17 +149,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
......
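Once bound, the new enum members are usable from Python like any pybind11 enum. A hypothetical usage sketch (the import path is an assumption for whatever name the compiled extension is exposed under):

    import transformer_engine_extensions as tex  # module name is an assumption

    bias_type = tex.NVTE_Bias_Type.NVTE_ALIBI
    mask_type = tex.NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK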
@@ -228,7 +228,7 @@ class _NoopCat(torch.autograd.Function):
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
-        param_temp.set_(full_param_buffer.storage(),
+        param_temp.set_(full_param_buffer.untyped_storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
......
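Context for the one-line change above: recent PyTorch deprecates Tensor.storage() in favor of Tensor.untyped_storage(), and Tensor.set_ accepts the untyped form directly. A self-contained sketch of the same zero-copy aliasing pattern:

    import torch

    buf = torch.arange(12, dtype=torch.float32)
    view = torch.empty(0, dtype=torch.float32)
    # alias elements 4..11 of buf as a 4x2 view, without copying
    view.set_(buf.untyped_storage(), 4, (4, 2), (2, 1))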
@@ -70,8 +70,8 @@ class TransformerLayer(torch.nn.Module):
.. note::
-    Argument :attr:`attention_mask` will be ignored in the `forward` call when
-    :attr:`self_attn_mask_type` is set to `"causal"`.
+    Argument :attr:`attention_mask` in the `forward` call is only used when
+    :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
Parameters
----------
@@ -127,7 +127,8 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
-    self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
+    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
+                        default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
@@ -491,7 +492,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None,
encoder_output: Optional[torch.Tensor] = None,
-        enc_dec_attn_mask: Optional[torch.Tensor] = None,
+        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
@@ -512,17 +513,23 @@
----------
hidden_states : torch.Tensor
Input tensor.
-        attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+        attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
-             Can be a tuple of 2 masks for cross attention with padding masks.
-        self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
-            type of attention mask passed into softmax operation.
+             It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
+             and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
+             for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
+        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
+                            default = `causal`
+            Type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
-        enc_dec_attn_mask : Optional[torch.Tensor], default = `None`
-             Boolean tensor used to mask out inter-attention softmax input if using
-             `layer_type="decoder"`.
+        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+             default = `None`. Boolean tensors used to mask out inter-attention softmax input if
+             using `layer_type="decoder"`. It should be a tuple of two masks in
+             [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
+             It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
+             for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
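A usage sketch of the documented mask shapes (sizes and values are illustrative; True marks positions to mask out, per the docstring above):

    import torch

    b, seqlen_q, seqlen_kv = 2, 6, 8          # hypothetical sizes
    pad_mask_q = torch.zeros(b, 1, 1, seqlen_q, dtype=torch.bool)
    pad_mask_q[0, ..., 4:] = True             # sample 0 has 4 valid query tokens
    pad_mask_kv = torch.zeros(b, 1, 1, seqlen_kv, dtype=torch.bool)
    pad_mask_kv[1, ..., 5:] = True            # sample 1 has 5 valid KV tokens

    enc_dec_attn_mask = (pad_mask_q, pad_mask_kv)   # tuple form for 'padding'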
@@ -545,7 +552,7 @@
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
-            Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`}
+            Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
@@ -569,7 +576,9 @@
hidden_states.shape[0] == self.seq_length // self.tp_size
), "Sequence dimension must be split across TP group when using sequence parallel."
if self_attn_mask_type != "causal" and attention_mask is not None:
if (("padding" in self_attn_mask_type
or self_attn_mask_type == "arbitrary")
and attention_mask is not None):
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
......
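A paraphrase (an assumption, mirroring the condition above) of when the boolean attention_mask is actually consumed:

    def mask_is_used(self_attn_mask_type: str) -> bool:
        return "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"

    # no_mask -> False, causal -> False, padding -> True,
    # padding_causal -> True, arbitrary -> True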