Unverified commit 32db3928, authored by cyanguwa, committed by GitHub

Integrate cuDNN frontend v1 to fused attention (#497)



* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes
* fix lint
* fix jax/paddle for unit tests
* fix jax/pytorch lint
* simplify stride generation
* fix and/or logic in get_backend
* fix flag_max512 and test_numerics
* remove v.contiguous() since get_qkv_layout covers it
* skip fp8 tests for sm89
* further fix jax CI
* fix jax CI
* revert mask type to comma-separated list
* fix lint
* fix last two commits
* integrate v1/pre-release-5
* cleanup prerelease5 integration and fix FA2.1 commit
* force dropout to 0 if not training
* fix Jax CI
* testing bias/alibi and padding+causal; add alibi to unfused DPA
* set flag_arb to false when non-determinism is not allowed
* followup on prev commit; remove redundant python env var setting
* WIP: minor tweaks for tests
* prepare for tests
* fix determinism logic for fused attn
* add bias to bwd
* fix gpt_checkpointing/dpa_accuracy problem
* fix some seg fault issues
* add failure notes
* remove use of non-deterministic var for backend selection
* minor fix for lint and CI
* fix workspace size in bwd and uncomment bias test
* fix get_alibi and remove check_support
* update tests status
* remove workspace_opt from FADescriptor_v1
* disable arbitrary backend + post scale bias in Jax; waiting on PR 525
* clean up bhsd order
* swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API
* remove support for padding_causal + cross for max512
* change alibi bias to float32 for bias_1_4/5 tests
* further clean up tests
* fix thd fwd output shape for FlashAttention and add backend info for DPA
* fix definition of workspace limit when dbias is present
* further tweak DP_WORKSPACE_LIMIT definition
* disallow alibi+no_mask for sdpa flash and update alibi tests
* update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests
* disable dbias for non-hopper archs
* fix layernorm lint
* remove unused arg for lint
* remove build dir in setup.py
* change selection logic to prefer fused attn on sm90
* fix distributed jax test
* fix h and s order in header
* update to cudnn fe v1 branch
* remove manual setting of workopt path due to dbias after v1 update
* fix paddle CI
* add post_scale_bias and alibi to sdpa flash support matrix
* fix support matrix in header files
* move headers back to .cu and change seed/offset to int64
* update Megatron commit in L1 test and remove all prints in fused attn test
* fix L1 Megatron test
* fix fp8 arg in L1 Megatron script
* print only when debug flag is on
* remove checkpoint loading to avoid loading other tests' results

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent ff760a9d
@@ -18,109 +18,56 @@ extern "C" {
 #endif

 /*! \enum NVTE_QKV_Layout
- * \brief Memory layouts of QKV tensors
- *   `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, the number of heads,
- *   head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
- *   `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
- *   or padded to the same length, and `THD`-based layouts are used when sequences have
- *   different lengths in a batch.
- * \note {`NVTE_QKV_INTERLEAVED`, `NVTE_KV_INTERLEAVED` and `NVTE_NOT_INTERLEAVED`
- *   will be deprecated in the next release. Please use their equivalent enums instead,
- *   i.e. `NVTE_T3HD`, `NVTE_THD_T2HD` and `NVTE_THD_THD_THD` when sequences are of
- *   variable lengths, and `NVTE_BS3HD`, `NVTE_BSHD_BS2HD` and `NVTE_BSHD_BSHD_BSHD`
- *   when sequences are of equal length or padded to equal length.}
+ * \brief Memory layouts of QKV tensors.
+ *   `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads,
+ *   head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
+ *   `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
+ *   or padded to the same length, and `THD`-based layouts are used when sequences have
+ *   different lengths in a batch.
  */
 enum NVTE_QKV_Layout {
-  /*! Separate Q, K, V tensors.
-     \verbatim
-     Q: [total_seqs_q, num_heads, head_dim]
-                    |     Q Q Q ... Q
-                    |     \___________  _____________/
-     total_seqs_q <|                  \/
-                    |       num_heads * head_dim
-     K: [total_seqs_kv, num_heads, head_dim]
-                    |     K K K ... K
-                    |     \___________  _____________/
-     total_seqs_kv <|                  \/
-                    |       num_heads * head_dim
-     V: [total_seqs_kv, num_heads, head_dim]
-                    |     V V V ... V
-                    |     \___________  _____________/
-     total_seqs_kv <|                  \/
-                    |       num_heads * head_dim
-     \endverbatim
-  */
-  NVTE_NOT_INTERLEAVED = 0,
-  /*! Packed QKV.
-     \verbatim
-     QKV: [total_seqs, 3, num_heads, head_dim]
-                  |     Q Q Q ... Q  K K K ... K  V V V ... V
-                  |     \___________  _____________/
-     total_seqs <|                  \/
-                  |       num_heads * head_dim
-     \endverbatim
-  */
-  NVTE_QKV_INTERLEAVED = 1,
-  /*! Q and packed KV.
-     \verbatim
-     Q: [total_seqs_q, num_heads, head_dim]
-                    |     Q Q Q ... Q
-                    |     \___________  _____________/
-     total_seqs_q <|                  \/
-                    |       num_heads * head_dim
-     KV: [total_seqs_kv, 2, num_heads, head_dim]
-                    |     K K K ... K  V V V ... V
-                    |     \___________  _____________/
-     total_seqs_kv <|                  \/
-                    |       num_heads * head_dim
-     \endverbatim
-  */
-  NVTE_KV_INTERLEAVED = 2,
-  NVTE_SB3HD = 3,
-  NVTE_SBH3D = 4,
-  NVTE_SBHD_SB2HD = 5,
-  NVTE_SBHD_SBH2D = 6,
-  NVTE_SBHD_SBHD_SBHD = 7,
-  NVTE_BS3HD = 8,
-  NVTE_BSH3D = 9,
-  NVTE_BSHD_BS2HD = 10,
-  NVTE_BSHD_BSH2D = 11,
-  NVTE_BSHD_BSHD_BSHD = 12,
-  NVTE_T3HD = 13,
-  NVTE_TH3D = 14,
-  NVTE_THD_T2HD = 15,
-  NVTE_THD_TH2D = 16,
-  NVTE_THD_THD_THD = 17,
+  NVTE_SB3HD = 0,            /*!< SB3HD layout */
+  NVTE_SBH3D = 1,            /*!< SBH3D layout */
+  NVTE_SBHD_SB2HD = 2,       /*!< SBHD_SB2HD layout */
+  NVTE_SBHD_SBH2D = 3,       /*!< SBHD_SBH2D layout */
+  NVTE_SBHD_SBHD_SBHD = 4,   /*!< SBHD_SBHD_SBHD layout */
+  NVTE_BS3HD = 5,            /*!< BS3HD layout */
+  NVTE_BSH3D = 6,            /*!< BSH3D layout */
+  NVTE_BSHD_BS2HD = 7,       /*!< BSHD_BS2HD layout */
+  NVTE_BSHD_BSH2D = 8,       /*!< BSHD_BSH2D layout */
+  NVTE_BSHD_BSHD_BSHD = 9,   /*!< BSHD_BSHD_BSHD layout */
+  NVTE_T3HD = 10,            /*!< T3HD layout */
+  NVTE_TH3D = 11,            /*!< TH3D layout */
+  NVTE_THD_T2HD = 12,        /*!< THD_T2HD layout */
+  NVTE_THD_TH2D = 13,        /*!< THD_TH2D layout */
+  NVTE_THD_THD_THD = 14,     /*!< THD_THD_THD layout */
 };

 /*! \enum NVTE_QKV_Layout_Group
- * \brief Grouping of QKV layouts
+ * \brief QKV layout groups
  */
 enum NVTE_QKV_Layout_Group {
-  /*! 3HD QKV layouts, e.g. BS3HD */
+  /*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */
   NVTE_3HD = 0,
-  /*! H3D QKV layouts, e.g. BSH3D */
+  /*! H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D */
   NVTE_H3D = 1,
-  /*! HD_2HD QKV layouts, e.g. BSHD_BS2HD */
+  /*! HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD */
   NVTE_HD_2HD = 2,
-  /*! HD_H2D QKV layouts, e.g. BSHD_BSH2D */
+  /*! HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D */
   NVTE_HD_H2D = 3,
-  /*! HD_HD_HD QKV layouts, e.g. BSHD_BSHD_BSHD */
+  /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
   NVTE_HD_HD_HD = 4,
 };

 /*! \enum NVTE_QKV_Format
- * \brief Dimension formats for QKV tensors
+ * \brief QKV formats
  */
 enum NVTE_QKV_Format {
-  /*! SBHD QKV format */
+  /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */
   NVTE_SBHD = 0,
-  /*! BSHD QKV format */
+  /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */
   NVTE_BSHD = 1,
-  /*! THD QKV format */
+  /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
   NVTE_THD = 2,
 };
@@ -133,7 +80,9 @@ enum NVTE_Bias_Type {
   /*! Bias before scale */
   NVTE_PRE_SCALE_BIAS = 1,
   /*! Bias after scale */
-  NVTE_POST_SCALE_BIAS = 2
+  NVTE_POST_SCALE_BIAS = 2,
+  /*! ALiBi */
+  NVTE_ALIBI = 3,
 };

 /*! \enum NVTE_Mask_Type
@@ -164,7 +113,7 @@ enum NVTE_Fused_Attn_Backend {
   NVTE_FP8 = 2,
 };

-/*! \brief Get layout group for a given QKV layout
+/*! \brief Get QKV layout group for a given QKV layout.
  *
  * \param[in] qkv_layout QKV layout, e.g. sbh3d.
  *
@@ -172,7 +121,7 @@ enum NVTE_Fused_Attn_Backend {
  */
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);

-/*! \brief Get QKV format for a given QKV layout
+/*! \brief Get QKV format for a given QKV layout.
 *
 * \param[in] qkv_layout QKV layout, e.g. sbh3d.
 *
@@ -188,6 +137,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
 * \param[in] bias_type       The attention bias type.
 * \param[in] attn_mask_type  The attention mask type.
 * \param[in] dropout         The dropout probability.
+ * \param[in] num_attn_heads  The number of heads in Q.
+ * \param[in] num_gqa_groups  The number of heads in K, V.
 * \param[in] max_seqlen_q    The sequence length of Q.
 * \param[in] max_seqlen_kv   The sequence length of K, V.
 * \param[in] head_dim        The head dimension of Q, K, V.
@@ -198,8 +149,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type,
-   float dropout, size_t max_seqlen_q,
-   size_t max_seqlen_kv, size_t head_dim);
+   float dropout,
+   size_t num_attn_heads, size_t num_gqa_groups,
+   size_t max_seqlen_q, size_t max_seqlen_kv,
+   size_t head_dim);
 /*! \brief Compute dot product attention with packed QKV input.
  *
@@ -211,14 +164,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *
  * Support Matrix:
   \verbatim
   | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
-  | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
-  | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
+  | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
+  | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
   \endverbatim
  *
- * \param[in]     QKV   The QKV tensor in packed format,
- *                      [total_seqs, 3, num_heads, head_dim].
+ * \param[in]     QKV   The QKV tensor in packed format, H3D or 3HD.
  * \param[in]     Bias  The Bias tensor.
  * \param[in,out] S     The S tensor.
  * \param[out]    O     The output O tensor.
@@ -256,14 +208,13 @@ void nvte_fused_attn_fwd_qkvpacked(
  *
  * Support Matrix:
   \verbatim
   | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
-  | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
-  | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
+  | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
+  | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
   \endverbatim
  *
- * \param[in]     QKV   The QKV tensor in packed format,
- *                      [total_seqs, 3, num_heads, head_dim].
+ * \param[in]     QKV   The QKV tensor in packed format, H3D or 3HD.
  * \param[in]     O     The O tensor from forward.
  * \param[in]     dO    The gradient of the O tensor.
  * \param[in]     S     The S tensor.
@@ -310,12 +261,13 @@ void nvte_fused_attn_bwd_qkvpacked(
  *
  * Support Matrix:
   \verbatim
   | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
+  | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
   \endverbatim
  *
- * \param[in]     Q     The Q tensor, [total_seqs_q, num_heads, head_dim].
- * \param[in]     KV    The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
+ * \param[in]     Q     The Q tensor, in HD layouts.
+ * \param[in]     KV    The KV tensor, in 2HD or H2D layouts.
  * \param[in]     Bias  The Bias tensor.
  * \param[in,out] S     The S tensor.
  * \param[out]    O     The output O tensor.
@@ -358,12 +310,13 @@ void nvte_fused_attn_fwd_kvpacked(
  *
  * Support Matrix:
   \verbatim
   | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
+  | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
   \endverbatim
  *
- * \param[in]     Q     The Q tensor, [total_seqs_q, num_heads, head_dim].
- * \param[in]     KV    The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
+ * \param[in]     Q     The Q tensor, in HD layouts.
+ * \param[in]     KV    The KV tensor, in H2D or 2HD layouts.
  * \param[in]     O     The O tensor from forward.
  * \param[in]     dO    The gradient of the O tensor.
  * \param[in]     S     The S tensor.
@@ -417,10 +370,12 @@ void nvte_fused_attn_bwd_kvpacked(
  *
  * Support Matrix:
   \verbatim
-  | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
-  | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
-  | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
+  | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
+  | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
+  |   |           | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | |
+  |   |           | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | |
+  | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
   \endverbatim
  *
  * \param[in]     Q     The Q tensor.
@@ -469,10 +424,12 @@ void nvte_fused_attn_fwd(
  *
  * Support Matrix:
   \verbatim
-  | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
-  | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
-  | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
-  | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
+  | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
+  | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
+  | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
+  |   |           | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | |
+  |   |           | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | |
+  | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
   \endverbatim
  *
  * \param[in]     Q     The Q tensor.
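Condensing the tables, a rough Python predicate for which backend a configuration can land on (an unofficial reading of the support matrices above; the real dispatch in nvte_get_fused_attn_backend also weighs dtype, layout group, and GPU architecture):

    def guess_backend(precision, bias, mask, max_seqlen, head_dim):
        """Rough reading of the support matrices; None means no fused backend."""
        if max_seqlen % 64 != 0:
            return None                          # every backend wants seqlen % 64 == 0
        if precision in ("fp16", "bf16"):
            if (max_seqlen <= 512 and head_dim == 64
                    and bias in ("no_bias", "post_scale_bias")):
                return 0                         # F16 max-512 backend
            if (head_dim <= 128 and head_dim % 8 == 0
                    and bias in ("no_bias", "post_scale_bias", "alibi")):
                return 1                         # F16 arbitrary-seqlen backend
        if (precision == "fp8" and max_seqlen <= 512 and head_dim == 64
                and bias == "no_bias" and mask == "padding"):
            return 2                             # FP8 backend
        return None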
......
@@ -1637,6 +1637,8 @@ class FusedAttnHelper:
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
+    num_heads_q: int
+    num_heads_kv: int
    max_seqlen_q: int
    max_seqlen_kv: int
    head_dim: int
@@ -1652,6 +1654,7 @@ class FusedAttnHelper:
            self.qkv_layout, self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_probability,
+            self.num_heads_q, self.num_heads_kv,
            self.max_seqlen_q, self.max_seqlen_kv,
            self.head_dim)
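The two new fields slot in between dropout_probability and the sequence lengths. Positionally (values illustrative, mirroring the call sites below; MHA passes equal head counts, GQA passes num_heads_kv < num_heads_q):

    helper = FusedAttnHelper(jnp.bfloat16, jnp.bfloat16,   # q/kv dtypes
                             NVTE_QKV_Layout.NVTE_BS3HD,
                             attn_bias_type, attn_mask_type,
                             0.0,                           # dropout_probability
                             16, 16,                        # num_heads_q, num_heads_kv
                             2048, 2048,                    # max_seqlen_q, max_seqlen_kv
                             64)                            # head_dim
    backend = helper.get_fused_attn_backend()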
@@ -1733,8 +1736,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
        output_dtype = qkv_dtype
        backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
-                                  attn_mask_type, dropout_probability, max_seqlen, max_seqlen,
-                                  head_dim).get_fused_attn_backend()
+                                  attn_mask_type, dropout_probability, num_head, num_head,
+                                  max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
@@ -2087,8 +2090,9 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
        output_dtype = q_dtype
        backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
-                                  attn_bias_type, attn_mask_type, dropout_probability, q_max_seqlen,
-                                  kv_max_seqlen, q_head_dim).get_fused_attn_backend()
+                                  attn_bias_type, attn_mask_type, dropout_probability,
+                                  q_num_head, kv_num_head,
+                                  q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
......
@@ -740,11 +740,13 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                             NVTE_Mask_Type mask_type, float dropout_probability,
+                                            size_t q_num_heads, size_t kv_num_heads,
                                             size_t q_max_seqlen, size_t kv_max_seqlen,
                                             size_t head_dim) {
   auto backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+      mask_type, dropout_probability, q_num_heads, kv_num_heads,
+      q_max_seqlen, kv_max_seqlen, head_dim);
   return backend;
 }
@@ -799,7 +801,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
   auto backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+      mask_type, dropout_probability, num_head, num_head,
+      q_max_seqlen, kv_max_seqlen, head_dim);
   PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   NVTETensorPack aux_output_tensors;
@@ -975,7 +978,8 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
   auto backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+      mask_type, dropout_probability, num_head, num_head,
+      q_max_seqlen, kv_max_seqlen, head_dim);
   PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   NVTETensorPack aux_output_tensors;
......
@@ -117,6 +117,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                             NVTE_Mask_Type mask_type, float dropout_probability,
+                                            size_t q_num_heads, size_t kv_num_heads,
                                             size_t q_max_seqlen, size_t kv_max_seqlen,
                                             size_t head_dim);
......
@@ -444,8 +444,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                               attn_bias_type, attn_mask_type,
-                                                               self.dropout_rate, q_seqlen,
-                                                               kv_seqlen, self.head_dim)
+                                                               self.dropout_rate,
+                                                               self.num_heads, self.num_heads,
+                                                               q_seqlen, kv_seqlen, self.head_dim)

        use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
            canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
......
@@ -40,13 +40,14 @@ class QKVLayout(Enum):

 def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
-                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
+                                   dropout_probability, num_heads_q, num_heads_kv,
+                                   max_seqlen_q, max_seqlen_kv, head_dim):
     """
     To check whether the fused attention kernel is available
     """
     return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
-                           attn_mask_type.value, dropout_probability, max_seqlen_q, max_seqlen_kv,
-                           head_dim).is_fused_attn_kernel_available()
+                           attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
+                           max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()

 def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
......
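A hedged sketch of calling the updated helper (argument order per the signature above; the layout, bias, and mask enum values are whatever the caller already holds):

    ok = is_fused_attn_kernel_available(
        jnp.bfloat16, jnp.bfloat16,     # q_type, kv_type
        qkv_layout, attn_bias_type, attn_mask_type,
        0.1,                            # dropout_probability
        16, 16,                         # num_heads_q, num_heads_kv
        2048, 2048,                     # max_seqlen_q, max_seqlen_kv
        64)                             # head_dim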
@@ -128,10 +128,12 @@ inline DType Int2NvteDType(int64_t dtype) {
 inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
     const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
+      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups,
+      max_seqlen_q, max_seqlen_kv, head_dim);
   return fused_attention_backend;
 }
......
@@ -237,8 +237,10 @@ class DotProductAttention(paddle.nn.Layer):
        self.fused_attention_backend = tex.get_fused_attn_backend(
            TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
            tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
-            AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv,
-            query_layer.shape[-1])
+            AttnMaskType[self.attn_mask_type], self.attention_dropout,
+            query_layer.shape[-2],
+            key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
+            max_s_q, max_s_kv, query_layer.shape[-1])

        is_backend_avail = (self.fused_attention_backend in [
            FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
......
@@ -69,6 +69,7 @@ else:

 _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
+_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))

 __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

@@ -119,13 +120,59 @@ class InferenceParams:    # pylint: disable=too-few-public-methods
            new_inference_value_memory,
        )

+@torch.no_grad()
+def get_alibi(
+        num_heads: int,
+        max_seqlen_q: int,
+        max_seqlen_kv: int,
+    ) -> torch.Tensor:
+    """
+    Generate ALiBi bias in the shape of [1, num_heads, max_seqlen_q, max_seqlen_kv].
+    """
+    n = 2 ** math.floor(math.log2(num_heads))
+    m_0 = 2.0 ** (-8.0 / n)
+    m = torch.pow(m_0, torch.arange(1, 1 + n))
+
+    if n < num_heads:
+        m_hat_0 = 2.0 ** (-4.0 / n)
+        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        m = torch.cat([m, m_hat])
+
+    a = torch.ones(max_seqlen_q, max_seqlen_kv)
+    b = torch.triu(a,diagonal=1)
+    c = b.cumsum(dim=-1)
+    bb = torch.tril(a,diagonal=-1)
+    cc = bb.cumsum(dim=0)
+    d = c - cc
+    bias = d.repeat(1, num_heads, 1, 1)
+
+    for i in range(num_heads):
+        bias[0,i,:,:] = m[i] * bias[0,i,:,:]
+    bias = bias.to(dtype=torch.float32, device="cuda")
+
+    return bias
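Each head h gets a slope m[h], and the bias at position (i, j) works out to m[h] * (j - i), so attention decays linearly with distance. A quick sanity check (needs a CUDA device, since the result is placed on "cuda"):

    bias = get_alibi(num_heads=8, max_seqlen_q=4, max_seqlen_kv=4)
    assert bias.shape == (1, 8, 4, 4)
    # For 8 heads the slopes are 2**-1 ... 2**-8; head 0, row 0 is
    # slope * (j - i) = [0.0, 0.5, 1.0, 1.5].
    print(bias[0, 0])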
+def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
+    """
+    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
+    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
+    the samples in a batch.
+    """
+    mask = mask.squeeze(1).squeeze(1)
+    reduced_mask = mask.sum(dim=1)
+    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
+    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
+    cu_seqlens = torch.cat((zero, cu_seqlens))
+    return cu_seqlens
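A toy example, assuming the convention that a True entry marks a valid token (that is what the sum over the sequence dimension measures); CUDA is required because the leading zero is allocated on "cuda":

    mask = torch.tensor([[1, 1, 1, 0],        # sample 0: length 3
                         [1, 1, 0, 0]],       # sample 1: length 2
                        dtype=torch.bool, device="cuda").view(2, 1, 1, 4)
    print(get_cu_seqlens(mask))               # tensor([0, 3, 5], dtype=torch.int32)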
 def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
-    tensor of shape [batch_size + 1,] containing the cumulative sequence
-    lengths of every sample in the batch and the indices containing valid
-    samples.
+    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
+    the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1]
+    containing the indices for the valid tokens.
     """
     mask = mask.squeeze(1).squeeze(1)
     bs, seqlen = mask.shape
@@ -147,6 +194,26 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return cu_seqlens, indices
+def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
+    """
+    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32
+    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
+    the valid tokens in a batch.
+    """
+    bs = len(cu_seqlens) - 1
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+    indices = [i*max_seqlen + ii for i,j in enumerate(seqlens) for ii in range(j)]
+    indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(
+            dtype=torch.int64, device="cuda")
+
+    num_nonzeros = indices.shape[0]
+    pad_amount = bs * max_seqlen - num_nonzeros
+    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
+                    mode="constant", value=float(bs * max_seqlen))
+
+    return indices
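Continuing the toy example: with cu_seqlens [0, 3, 5] and max_seqlen 4, the valid flat positions are 0, 1, 2 (sample 0) and 4, 5 (sample 1), and the tail is padded with the out-of-range index bs * max_seqlen:

    cu = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
    idx = get_indices(max_seqlen=4, cu_seqlens=cu)
    print(idx.shape)                # torch.Size([8, 1, 1])
    print(idx.flatten().tolist())   # [0, 1, 2, 4, 5, 8, 8, 8]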
 @jit_fuser
 def pack_tensor(
     indices: torch.Tensor,
@@ -290,34 +357,6 @@ class UnpackTensor(torch.autograd.Function):
         return None, None, pack_tensor(ctx.indices, grad_output)

-def _unpack_attn_mask_type(attn_mask_type: str) -> Tuple[str, bool]:
-    """
-    Unpacks the attention mask type string and returns a single mask type
-    and a boolean for whether to apply causal mask. Also ensures that the
-    combination of masks passed in is supported by one of the attention
-    backends available.
-    """
-    mask_types = attn_mask_type.split(',')
-    assert (
-        all(mask_type in AttnMaskTypes for mask_type in mask_types)
-    ), f"Mask type {attn_mask_type} is not supported."
-
-    # Whether or not to apply causal mask toggle.
-    causal_mask = False
-    if "causal" in mask_types:
-        mask_types.remove("causal")
-        causal_mask = True
-
-    if len(mask_types) == 0:    # Only apply causal mask.
-        return "causal", True
-    if len(mask_types) == 1 and causal_mask:    # Causal + padding masks
-        assert mask_types[0] == "padding", f"Causal + {mask_types[0]} masking not supported."
-        return "padding", True
-    if len(mask_types) == 1:    # Arbitrary or padding or no_mask
-        return mask_types[0], False
-    raise RuntimeError("Unsupported combination of mask types.")
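With _unpack_attn_mask_type removed, the comma-separated mask string is now consumed directly by substring checks (see the FlashAttention changes below). A sketch of the convention:

    attn_mask_type = "padding,causal"              # comma-separated list, per this PR
    apply_padding = "padding" in attn_mask_type    # pack tensors / use cu_seqlens
    apply_causal = "causal" in attn_mask_type      # pass causal=True to the kernels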
 def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                                recv_tensor, recv_src,
                                cp_group, batch_p2p_comm):
@@ -878,7 +917,6 @@ class _SplitAlongDim(torch.autograd.Function):
         return torch.cat(grad_outputs, dim = split_dim), None, None
-
 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
     BMM1 -> softmax + dropout -> BMM2
@@ -921,7 +959,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """core attention fprop"""
+        """Unfused attention fprop"""
         assert (qkv_layout in QKVLayouts
                 ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
@@ -932,9 +970,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
             # convert to sbhd and use sbhd implementation for now
             query_layer, key_layer, value_layer = [x.transpose(0, 1)
                                                    for x in [query_layer, key_layer, value_layer]]
-        assert (
-            attn_mask_type in AttnMaskTypes
-        ), f"attn_mask_type {attn_mask_type} not supported"

         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
         apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
@@ -1003,10 +1038,13 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 + core_attention_bias).view(-1, output_size[2], output_size[3])
             matmul_result /= scale

-        elif core_attention_bias_type == "post_scale_bias":
-            assert core_attention_bias is not None, "core_attention_bias should not be None!"
-            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
-                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
+        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
+            if core_attention_bias_type == "post_scale_bias":
+                assert core_attention_bias is not None, "core_attention_bias should not be None!"
+                assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
+                        ), "core_attention_bias must be in [1, h, sq, skv] shape!"
+            if core_attention_bias_type == "alibi":
+                core_attention_bias = get_alibi(output_size[1], output_size[2], output_size[3])
             matmul_result = torch.baddbmm(
                 matmul_result,
                 query_layer.transpose(0, 1),    # [b * np, sq, hn]
@@ -1016,7 +1054,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
             )
             matmul_result = (matmul_result.view(
                 output_size[0], output_size[1], output_size[2], output_size[3])
-                + core_attention_bias).view(-1, output_size[2], output_size[3])
+                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
+                        dtype=query_layer.dtype)

             # change view to [b, np, sq, sk]
             attention_scores = matmul_result.view(*output_size)
@@ -1173,7 +1212,6 @@ def _get_qkv_layout(
     check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
                                          for i, x in enumerate([k, v]))

-    qkv_layout = None
     if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
         and check_last_two_dims_offsets_qkv
         and not check_last_dim_offsets_qkv):
@@ -1297,7 +1335,7 @@ class FlashAttention(torch.nn.Module):
                 for x in [query_layer, key_layer, value_layer]
             ]

-        if attn_mask_type == 'padding':
+        if 'padding' in attn_mask_type:
             assert not context_parallel, "Padding mask not supported with context parallelism."

             if self.attention_type == "self":
@@ -1305,15 +1343,31 @@ class FlashAttention(torch.nn.Module):
                     max_seqlen_q == max_seqlen_kv
                 ), "Maximum sequence length for Q and KV should be the same."
                 if self.layer_number == 1:
-                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
+                    if cu_seqlens_q is None:
+                        assert (attention_mask is not None
+                                ), "Please provide attention_mask for padding!"
+                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
+                    else:
+                        _cu_seqlens_q = cu_seqlens_q
+                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                     _cu_seqlens_kv = _cu_seqlens_q
                 query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_q, query_layer, key_layer, value_layer
                 )
             else:
                 if self.layer_number == 1:
-                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask[0])
-                    _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
+                    if cu_seqlens_q is None or cu_seqlens_kv is None:
+                        assert (attention_mask is not None
+                                ), "Please provide attention_mask for padding!"
+                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
+                            attention_mask[0])
+                        _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
+                            attention_mask[1])
+                    else:
+                        _cu_seqlens_q = cu_seqlens_q
+                        _cu_seqlens_kv = cu_seqlens_kv
+                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                        _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                 query_layer_packed = PackTensors.apply(_indices_q, query_layer)
                 key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_kv, key_layer, value_layer
@@ -1355,7 +1409,7 @@ class FlashAttention(torch.nn.Module):
                     self.attention_dropout if self.training else 0.0,
                     cp_group, cp_global_ranks, cp_stream,
                     softmax_scale=1.0/self.norm_factor,
-                    causal=attn_mask_type=="causal",
+                    causal="causal" in attn_mask_type,
                     deterministic=self.deterministic
                 )
         else:
@@ -1367,11 +1421,11 @@ class FlashAttention(torch.nn.Module):
                     query_layer, key_layer, value_layer,
                     cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                     self.attention_dropout if self.training else 0.0,
-                    softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
+                    softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
                     **fa_optional_forward_kwargs
                 )

-        if attn_mask_type == 'padding':
+        if 'padding' in attn_mask_type:
             output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)

         if qkv_format == 'sbhd':
@@ -1380,6 +1434,9 @@ class FlashAttention(torch.nn.Module):
         elif qkv_format == 'bshd':
             # (bs)hd -> bs(hd)
             output = output.view(batch_size, max_seqlen_q, -1).contiguous()
+        elif qkv_format == 'thd':
+            # thd -> t(hd)
+            output = output.view(output.shape[0], -1).contiguous()

         return output
@@ -1416,6 +1473,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
         qkv, out, cu_seqlens = ctx.saved_tensors
+        if not ctx.aux_ctx_tensors[0].is_contiguous():
+            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
         if ctx.use_FAv2_bwd:
             softmax_lse, rng_state = ctx.aux_ctx_tensors
             dqkv = torch.empty_like(qkv)
@@ -1426,7 +1485,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
                 cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                 ctx.dropout_p, ctx.attn_scale, False,
-                ctx.attn_mask_type == "causal", None, rng_state
+                "causal" in ctx.attn_mask_type, None, rng_state
             )
             dqkv = dqkv[..., :d_out.shape[-1]]
         else:
@@ -1438,8 +1497,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                 ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

-        # if no_bias, return dqkv
-        if ctx.attn_bias_type == "no_bias":
+        # if no_bias or alibi, return dqkv
+        if ctx.attn_bias_type in ["no_bias", "alibi"]:
             return (None, None, None, dqkv, None, None, None,
                     None, None, None, None, None,
                     None, None, None, None, None, None)
@@ -1482,6 +1541,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
+        if not ctx.aux_ctx_tensors[0].is_contiguous():
+            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
         if ctx.use_FAv2_bwd:
             softmax_lse, rng_state = ctx.aux_ctx_tensors
             dq = torch.empty_like(q)
@@ -1493,7 +1554,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                 d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
                 cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                 ctx.dropout_p, ctx.attn_scale, False,
-                ctx.attn_mask_type == "causal", None, rng_state
+                "causal" in ctx.attn_mask_type, None, rng_state
             )
             dq = dq[..., :d_out.shape[-1]]
             dkv = dkv[..., :d_out.shape[-1]]
@@ -1507,8 +1568,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                 ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                 ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

-        # if no_bias, return dqkv
-        if ctx.attn_bias_type == "no_bias":
+        # if no_bias or alibi, return dqkv
+        if ctx.attn_bias_type in ["no_bias", "alibi"]:
             return (None, None, None, None, None, dq, dkv, None, None, None,
                     None, None, None, None, None,
                     None, None, None, None, None, None)
@@ -1551,6 +1612,8 @@ class FusedAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
+        if not ctx.aux_ctx_tensors[0].is_contiguous():
+            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
         if ctx.use_FAv2_bwd:
             softmax_lse, rng_state = ctx.aux_ctx_tensors
             dq = torch.empty_like(q)
@@ -1563,7 +1626,7 @@ class FusedAttnFunc(torch.autograd.Function):
                 d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                 cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                 ctx.dropout_p, ctx.attn_scale, False,
-                ctx.attn_mask_type == "causal", None, rng_state
+                "causal" in ctx.attn_mask_type, None, rng_state
             )
             dq = dq[..., :d_out.shape[-1]]
             dk = dk[..., :d_out.shape[-1]]
@@ -1578,8 +1641,8 @@ class FusedAttnFunc(torch.autograd.Function):
                 ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                 ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

-        # if no_bias, return dqkv
-        if ctx.attn_bias_type == "no_bias":
+        # if no_bias or alibi, return dqkv
+        if ctx.attn_bias_type in ["no_bias", "alibi"]:
             return (None, None, None, None, None, dq, dk, dv, None, None, None,
                     None, None, None, None, None,
                     None, None, None, None, None, None)
@@ -1602,19 +1665,17 @@ class FusedAttention(torch.nn.Module):
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
-    | attn_type     | self/cross              | self                           |
+    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
-    |  - qkv        | qkv_interleaved         | qkv_interleaved                |
-    |  - (q,kv)     | kv_interleaved          |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
-    | mask_type     | causal/no_mask          | causal                         |
-    | bias_type     | no_bias/post_scale_bias | no_bias                        |
+    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
+    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
-    | max_seqlen    | <=512                   | any                            |
-    | head_dim      | 64                      | 64,128                         |
+    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
+    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """
@@ -1624,6 +1685,8 @@ class FusedAttention(torch.nn.Module):
    attention_dropout: float = 0.0,
    attention_dropout_ctx: Optional[Callable] = nullcontext,
    attention_type: str = "self",
    layer_number: Optional[int] = None,
    deterministic: bool = False,
) -> None:
    super().__init__()
@@ -1634,6 +1697,22 @@ class FusedAttention(torch.nn.Module):
self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                     and _flash_attn_2_available
                     and get_device_compute_capability() == (9, 0))
self.layer_number = 1 if layer_number is None else layer_number
if deterministic:
    # workspace optimization path is deterministic
    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
#          or when bias gradient needs to be computed
# - n: enables workspace optimization when required workspace is <= n bytes
# - -1: enables workspace optimization always
# - 0: disables workspace optimization always
if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
    if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
        os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
    if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
        os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
def forward(
    self,
@@ -1644,6 +1723,7 @@ class FusedAttention(torch.nn.Module):
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    attn_mask_type: str = "causal",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    fused_attention_backend:
        tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
    core_attention_bias_type: str = "no_bias",
@@ -1668,6 +1748,10 @@ class FusedAttention(torch.nn.Module):
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (
    qkv_format != 'thd'
), 'FusedAttention does not support qkv_format = thd!'
if qkv_format in ['sbhd', 'bshd']:
    if qkv_format == 'sbhd':
        batch_size, max_seqlen_q, max_seqlen_kv = (
@@ -1675,31 +1759,48 @@ class FusedAttention(torch.nn.Module):
    if qkv_format == 'bshd':
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
    if 'padding' in attn_mask_type:
        global _cu_seqlens_q, _cu_seqlens_kv
        if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
            # use cu_seqlens when both cu_seqlens and attention_mask are present
            if self.layer_number == 1:
                _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
        elif attention_mask is not None:
            if self.attention_type == "self":
                if self.layer_number == 1:
                    _cu_seqlens_q = get_cu_seqlens(attention_mask)
                    _cu_seqlens_kv = _cu_seqlens_q
            else:
                if self.layer_number == 1:
                    _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                    _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
        else:
            raise Exception("Please provide attention_mask or cu_seqlens for padding!")
        cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
    else:
        if self.layer_number == 1:
            if cu_seqlens_q is None:
                cu_seqlens_q = torch.arange(
                    0,
                    (batch_size + 1) * max_seqlen_q,
                    step=max_seqlen_q,
                    dtype=torch.int32,
                    device=query_layer.device)
            if cu_seqlens_kv is None:
                cu_seqlens_kv = torch.arange(
                    0,
                    (batch_size + 1) * max_seqlen_kv,
                    step=max_seqlen_kv,
                    dtype=torch.int32,
                    device=key_layer.device)
            _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
        else:
            cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
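`get_cu_seqlens` is defined elsewhere in this module; a minimal standalone equivalent, assuming the mask convention documented below (a boolean [batch_size, 1, 1, seqlen] tensor in which True marks a padded position), might look like:

import torch

def cu_seqlens_from_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [batch_size, 1, 1, seqlen], bool, True = padded position
    seqlens = (~attention_mask).squeeze(1).squeeze(1).sum(dim=1, dtype=torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=seqlens.device)
    # cumulative lengths in shape [batch_size + 1], starting at 0
    return torch.cat([zero, torch.cumsum(seqlens, dim=0, dtype=torch.int32)])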
qkv_dtype = TE_DType[query_layer.dtype]
use_FAv2_bwd = (self.use_FAv2_bwd
                and (core_attention_bias_type == "no_bias")
                and (fused_attention_backend
                     == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
with self.attention_dropout_ctx():
@@ -1733,7 +1834,7 @@ class DotProductAttention(torch.nn.Module):
.. note::

    Argument :attr:`attention_mask` in the `forward` call is only used when
    :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

.. warning::
@@ -1759,18 +1860,20 @@ class DotProductAttention(torch.nn.Module):
attention_dropout: float, default = 0.0
    dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
    type of attention mask passed into softmax operation, options are "`no_mask`",
    "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
    "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
    This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
    It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
    forward arg is useful for dynamically changing mask types, e.g. a different mask
    for training and inference. For "`no_mask`", no attention mask is applied. For
    "`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
    and applies an upper triangular mask to the softmax input; no user input is
    needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
    provide the locations of padded tokens, either via :attr:`cu_seqlens_q` and
    :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or via :attr:`attention_mask`
    in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
    need to provide a mask that is broadcastable to the shape of the softmax input.
attention_type: str, default = `self`
    type of attention, either "`self`" or "`cross`".
layer_number: int, default = `None`
@@ -1786,12 +1889,6 @@ class DotProductAttention(torch.nn.Module):
    have different lengths. Please note that these formats do not reflect how
    tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
    For that, please use `_get_qkv_layout` to gain the layout information.
Parallelism parameters
----------------------
@@ -1833,6 +1930,9 @@ class DotProductAttention(torch.nn.Module):
super().__init__()

self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
    attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
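A quick standalone check of the normalization above: both comma-separated forms accepted by the constructor collapse to the single `padding_causal` token used internally.

for mask in ("padding,causal", "causal,padding"):
    normalized = mask.replace(",", "_")
    if normalized == "causal_padding":
        normalized = "padding_causal"
    assert normalized == "padding_causal"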
@@ -1900,9 +2000,11 @@ class DotProductAttention(torch.nn.Module):
# Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs.
if self.use_fused_attention:
    self.fused_attention = FusedAttention(norm_factor,
                                          attention_type=attention_type,
                                          layer_number=layer_number,
                                          deterministic=self.deterministic,
                                          **attn_kwargs)
self.unfused_attention = UnfusedDotProductAttention(
    norm_factor, **attn_kwargs, layer_number=layer_number)
@@ -2016,9 +2118,13 @@ class DotProductAttention(torch.nn.Module):
    Key tensor.
value_layer : torch.Tensor
    Value tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    default = `None`. Boolean tensor(s) used to mask out attention softmax input.
    It should be `None` for `causal` and `no_mask` types. For `padding` masks, it
    should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention,
    and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and
    [batch_size, 1, 1, seqlen_kv] for cross-attention. For the `arbitrary` mask
    type, it should be in a shape broadcastable to
    [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
qkv_format: str, default = `None`
    If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
@@ -2027,17 +2133,19 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
    Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
    with shape [batch_size + 1] and dtype torch.int32.
attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
    `arbitrary`}, default = `None`. Type of attention mask passed into the
    softmax operation; `padding,causal` and `causal,padding` are equivalent.
checkpoint_core_attention : bool, default = `False`
    If true, forward activations for attention are recomputed
    during the backward pass in order to save memory that would
    otherwise be occupied to store the forward activations until
    backprop.
core_attention_bias_type: str, default = `no_bias`
    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
    It should be `None` for `no_bias` and `alibi` bias types.
fast_zero_fill: bool, default = `True`
    Whether to use the fast path to set output tensors to 0 or not.
"""
@@ -2051,9 +2159,16 @@ class DotProductAttention(torch.nn.Module):
if attn_mask_type is None:
    attn_mask_type = self.attn_mask_type
else:
    attn_mask_type = attn_mask_type.replace(",", "_")
    if attn_mask_type == "causal_padding":
        attn_mask_type = "padding_causal"
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format is None:
    qkv_format = self.qkv_format
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
        and value_layer.shape[-2] == self.num_gqa_groups_per_partition
@@ -2136,7 +2251,7 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False

if (_flash_attn_2_1_plus
        and "causal" in attn_mask_type
        and max_seqlen_q != max_seqlen_kv):
    warnings.warn(
        "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
@@ -2156,19 +2271,15 @@ class DotProductAttention(torch.nn.Module):
# Filter: Attention mask type.
#    attn_mask_type(s)  |  supported backends
# ------------------------------------------------
#    no_mask            |  All
#    padding            |  UnfusedDotProductAttention, FlashAttention, FusedAttention
#    causal             |  All
#    padding + causal   |  FlashAttention, FusedAttention
#    arbitrary          |  UnfusedDotProductAttention
#
if attn_mask_type == "arbitrary":
    use_flash_attention = False
    use_fused_attention = False
if use_fused_attention:
    fused_attention_backend = tex.get_fused_attn_backend(
@@ -2178,23 +2289,28 @@ class DotProductAttention(torch.nn.Module):
        AttnBiasType[core_attention_bias_type],
        AttnMaskType[attn_mask_type],
        self.attention_dropout,
        query_layer.shape[-2],  # num_attn_heads
        key_layer.shape[-2],  # num_gqa_groups
        max_seqlen_q,
        max_seqlen_kv,
        query_layer.shape[-1],  # head_dim
    )
    # DPA does not support FP8; for FP8, use cpp_extensions modules directly
    is_backend_avail = (fused_attention_backend in
        [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
    use_fused_attention = (use_fused_attention
                           and is_backend_avail)

# Select FusedAttention on sm90 and FlashAttention on others for performance
if (use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
    if self.device_compute_capability == (9, 0):
        use_flash_attention = False
if use_flash_attention:
    if _NVTE_DEBUG:
        print("[DotProductAttention]: using flash-attn", _flash_attn_version)
    return self.flash_attention(query_layer,
                                key_layer,
                                value_layer,
@@ -2212,6 +2328,9 @@ class DotProductAttention(torch.nn.Module):
), "Context parallelism is only implemented with Flash Attention!"

if use_fused_attention:
    if _NVTE_DEBUG:
        print("[DotProductAttention]: using cuDNN fused attention (backend "
              + str(int(fused_attention_backend)) + ")")
    if checkpoint_core_attention:
        return self._checkpointed_attention_forward(self.fused_attention,
                                                    query_layer,
@@ -2221,6 +2340,7 @@ class DotProductAttention(torch.nn.Module):
    cu_seqlens_q = cu_seqlens_q,
    cu_seqlens_kv = cu_seqlens_kv,
    attn_mask_type = attn_mask_type,
    attention_mask = attention_mask,
    fused_attention_backend = fused_attention_backend,
    core_attention_bias_type = core_attention_bias_type,
    core_attention_bias = core_attention_bias,
@@ -2230,11 +2350,14 @@ class DotProductAttention(torch.nn.Module):
    cu_seqlens_q = cu_seqlens_q,
    cu_seqlens_kv = cu_seqlens_kv,
    attn_mask_type = attn_mask_type,
    attention_mask = attention_mask,
    fused_attention_backend = fused_attention_backend,
    core_attention_bias_type = core_attention_bias_type,
    core_attention_bias = core_attention_bias,
    fast_zero_fill = fast_zero_fill)

if _NVTE_DEBUG:
    print("[DotProductAttention]: using unfused DPA")
if checkpoint_core_attention:
    return self._checkpointed_attention_forward(
        self.unfused_attention,
@@ -2267,8 +2390,8 @@ class MultiheadAttention(torch.nn.Module):
.. note::

    Argument :attr:`attention_mask` in the `forward` call is only used when
    :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

Parameters
----------
@@ -2295,7 +2418,8 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
    layer number of the current `TransformerLayer` when multiple such modules are
    concatenated to form a transformer block.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
    default = `causal`
    type of attention mask passed into softmax operation. Overridden by
    :attr:`attn_mask_type` in the `forward` method. The forward
    arg is useful for dynamically changing mask types, e.g. a different
@@ -2646,16 +2770,22 @@ class MultiheadAttention(torch.nn.Module):
.. note::

    Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
    includes `"padding"` or `"arbitrary"`.

Parameters
----------
hidden_states : torch.Tensor
    Input tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    default = `None`. Boolean tensor(s) used to mask out attention softmax input.
    It should be `None` for `causal` and `no_mask` types. For `padding` masks, it
    should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention,
    and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and
    [batch_size, 1, 1, seqlen_kv] for cross-attention. For the `arbitrary` mask
    type, it should be in a shape broadcastable to
    [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
    default = `None`
    type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
    Output of the encoder block to be fed into the decoder block if using
@@ -2682,9 +2812,10 @@ class MultiheadAttention(torch.nn.Module):
    Embeddings for query and key tensors for applying rotary position
    embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
    It should be `None` for `no_bias` and `alibi` bias types.
fast_zero_fill: bool, default = `True`
    Whether to set output tensors to 0 or not before use.
"""
@@ -2693,10 +2824,11 @@ class MultiheadAttention(torch.nn.Module):
if attn_mask_type is None:
    attn_mask_type = self.attn_mask_type

if "padding" in attn_mask_type and attention_mask is not None:
    for i, _ in enumerate(attention_mask):
        assert (
            attention_mask[i].dtype == torch.bool
        ), "Attention mask must be in boolean type!"
assert (core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
@@ -22,11 +22,11 @@ TE_DType = {
    torch.bfloat16: tex.DType.kBFloat16,
}

AttnMaskTypes = ("causal", "padding", "padding_causal", "arbitrary", "no_mask")

AttnTypes = ("self", "cross")

AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi")

QKVLayouts = (
    "sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
@@ -33,9 +33,6 @@ TORCH_DType = {
}

QKVLayout = {
    "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
    "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
    "sbhd_sb2hd": NVTE_QKV_Layout.NVTE_SBHD_SB2HD,
@@ -57,12 +54,14 @@ AttnBiasType = {
    "no_bias": NVTE_Bias_Type.NVTE_NO_BIAS,
    "pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
    "post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
    "alibi": NVTE_Bias_Type.NVTE_ALIBI,
}

AttnMaskType = {
    "no_mask": NVTE_Mask_Type.NVTE_NO_MASK,
    "padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
    "causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
    "padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
}

FusedAttnBackend = {
@@ -76,84 +75,6 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
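The AttnBiasType and AttnMaskType dictionaries above translate the Python-level strings into the C++ enums; for instance, the normalized `padding_causal` string from attention.py resolves as follows (illustrative lookups):

mask_enum = AttnMaskType["padding_causal"]  # NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
bias_enum = AttnBiasType["alibi"]           # NVTE_Bias_Type.NVTE_ALIBI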
def fused_attn_fwd_qkvpacked(
    is_training: bool,
    max_seqlen: int,
@@ -170,7 +91,7 @@ def fused_attn_fwd_qkvpacked(
    attn_scale: float = None,
    dropout: float = 0.0,
    fast_zero_fill: bool = True,
    qkv_layout: str = "sbh3d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    rng_gen: torch.Generator = None,
@@ -188,8 +109,7 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens: torch.Tensor
    cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
    input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details)
qkv_dtype: tex.DType
    data type of QKV; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
@@ -216,12 +136,12 @@ def fused_attn_fwd_qkvpacked(
fast_zero_fill: bool, default = True
    if True, initializes the output tensor O to zero using the fast filling method;
    if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
    layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"}
attn_bias_type: str, default = "no_bias"
    type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
    type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None
    random number generator;
    if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
@@ -230,7 +150,7 @@ def fused_attn_fwd_qkvpacked(
----------
o: torch.Tensor
    output tensor O, of the attention calculation; same data type as QKV;
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
aux_ctx_tensors: List[torch.Tensor]
    auxiliary output tensors used for the backward;
    if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
@@ -257,21 +177,14 @@ def fused_attn_fwd_qkvpacked(
    [seed, offset], dtype uint64
"""
if attn_scale is None:
    d = qkv.size(-1)
    attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type not in ["no_bias", "alibi"]:
    assert (attn_bias is not None
    ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
    h = qkv.size(2) if 'h3d' in qkv_layout else qkv.size(3)
    assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
    ), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
    assert (attn_bias.dtype == qkv.dtype
@@ -304,16 +217,10 @@ def fused_attn_fwd_qkvpacked(
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_fwd_qkvpacked(
    max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
    QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
    cu_seqlens, qkv, qkv_dtype,
    d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
@@ -345,7 +252,7 @@ def fused_attn_bwd_qkvpacked(
    attn_scale: float = None,
    dropout: float = 0.0,
    fast_zero_fill: bool = True,
    qkv_layout: str = "sbh3d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
@@ -359,14 +266,13 @@ def fused_attn_bwd_qkvpacked(
cu_seqlens: torch.Tensor
    cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
    input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details)
o: torch.Tensor
    input tensor O (output of forward);
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
d_o: torch.Tensor
    input tensor dO (gradient of O);
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
    data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
@@ -401,12 +307,12 @@ def fused_attn_bwd_qkvpacked(
fast_zero_fill: bool, default = True
    if True, initializes the output tensor O to zero using the fast filling method;
    if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
    layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"}
attn_bias_type: str, default = "no_bias"
    type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
    type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}

Returns
----------
@@ -417,18 +323,8 @@ def fused_attn_bwd_qkvpacked(
    or "post_scale_bias"; same data type and shape as Bias
"""
if attn_scale is None:
    d = qkv.size(-1)
    attn_scale = 1.0 / math.sqrt(d)

assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
@@ -437,8 +333,6 @@ def fused_attn_bwd_qkvpacked(
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
    assert (len(aux_ctx_tensors) >= 1
    ), "aux_ctx_tensors must contain rng_state as its last element."
if fused_attention_backend == FusedAttnBackend["FP8"]: if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
@@ -452,34 +346,17 @@ def fused_attn_bwd_qkvpacked(
    assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
    assert (len(aux_ctx_tensors) == 3
    ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_bwd_qkvpacked(
    max_seqlen, attn_scale, dropout, fast_zero_fill,
    QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
    cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors,
    d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
    q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)

return output_tensors
def fused_attn_fwd_kvpacked(
@@ -501,7 +378,7 @@ def fused_attn_fwd_kvpacked(
    attn_scale: float = None,
    dropout: float = 0.0,
    fast_zero_fill: bool = True,
    qkv_layout: str = "sbhd_sbh2d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    rng_gen: torch.Generator = None,
@@ -524,12 +401,9 @@ def fused_attn_fwd_kvpacked(
cu_seqlens_kv: torch.Tensor
    cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
    input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details)
kv: torch.Tensor
    packed input tensor KV; shape 2hd or h2d (see `qkv_layout` for details)
qkv_dtype: tex.DType
    data type of Q and KV; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
@@ -556,12 +430,13 @@ def fused_attn_fwd_kvpacked(
fast_zero_fill: bool, default = True
    if True, initializes the output tensor O to zero using the fast filling method;
    if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbhd_sbh2d"
    layout of QKV;
    {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"}
attn_bias_type: str, default = "no_bias"
    type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
    type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None
    random number generator;
    if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
@@ -570,7 +445,7 @@ def fused_attn_fwd_kvpacked(
----------
o: torch.Tensor
    output tensor O, of the attention calculation; same data type as QKV;
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
aux_ctx_tensors: List[torch.Tensor]
    auxiliary output tensors used for the backward;
    if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
@@ -597,29 +472,14 @@ def fused_attn_fwd_kvpacked(
    [seed, offset], dtype uint64
"""
if attn_scale is None:
    d = q.size(-1)
    attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type not in ["no_bias", "alibi"]:
    assert (attn_bias is not None
    ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
    h = q.size(2)
    assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
    ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
    assert (attn_bias.dtype == q.dtype
@@ -644,8 +504,7 @@ def fused_attn_fwd_kvpacked(
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
    max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
    QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
    cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
    d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
@@ -680,7 +539,7 @@ def fused_attn_bwd_kvpacked(
    attn_scale: float = None,
    dropout: float = 0.0,
    fast_zero_fill: bool = True,
    qkv_layout: str = "sbhd_sbh2d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
@@ -699,18 +558,15 @@ def fused_attn_bwd_kvpacked(
cu_seqlens_kv: torch.Tensor
    cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
    input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details)
kv: torch.Tensor
    packed input tensor KV; shape h2d or 2hd (see `qkv_layout` for details)
o: torch.Tensor
    input tensor O (output of forward);
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
d_o: torch.Tensor
    input tensor dO (gradient of O);
    same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
    data type of QKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
@@ -746,12 +602,13 @@ def fused_attn_bwd_kvpacked(
fast_zero_fill: bool, default = True
    if True, initializes the output tensor O to zero using the fast filling method;
    if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbhd_sbh2d"
    layout of QKV;
    {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"}
attn_bias_type: str, default = "no_bias"
    type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
    type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}

Returns
----------
@@ -764,26 +621,8 @@ def fused_attn_bwd_kvpacked(
    or "post_scale_bias"; same data type and shape as Bias
"""
if attn_scale is None:
    d = q.size(-1)
    attn_scale = 1.0 / math.sqrt(d)

assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
@@ -792,8 +631,6 @@ def fused_attn_bwd_kvpacked(
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
    assert (len(aux_ctx_tensors) >= 1
    ), "aux_ctx_tensors must contain rng_state as its last element."
if fused_attention_backend == FusedAttnBackend["FP8"]: if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
@@ -807,34 +644,18 @@ def fused_attn_bwd_kvpacked(
    assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
    assert (len(aux_ctx_tensors) == 3
    ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_bwd_kvpacked(
    max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
    QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
    cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors,
    d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
    q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)

return output_tensors
def fused_attn_fwd(
    is_training: bool,
@@ -881,23 +702,11 @@ def fused_attn_fwd(
cu_seqlens_kv: torch.Tensor
    cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
    input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor
    input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor
    input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
qkv_dtype: tex.DType
    data type of Q, K and V; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
...@@ -930,9 +739,9 @@ def fused_attn_fwd( ...@@ -930,9 +739,9 @@ def fused_attn_fwd(
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias" attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
rng_gen: torch.Generator, default = None rng_gen: torch.Generator, default = None
random number generator; random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
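For reference, a quick sketch of what the three layout families in `qkv_layout` mean as raw tensor shapes. The sizes below are hypothetical; only the sbhd/bshd/thd naming and the layout strings come from the docstring above.

```python
import torch

b, s, h, d = 2, 128, 16, 64           # hypothetical batch, seqlen, heads, head_dim
t = b * s                             # total tokens across the batch (thd layouts)

q_sbhd = torch.empty(s, b, h, d)      # "sbhd": sequence-first, e.g. "sbhd_sbhd_sbhd"
q_bshd = torch.empty(b, s, h, d)      # "bshd": batch-first, e.g. "bshd_bshd_bshd"
q_thd = torch.empty(t, h, d)          # "thd": variable-length, tokens flattened,
                                      # sequence boundaries given by cu_seqlens

qkv_bs3hd = torch.empty(b, s, 3, h, d)  # packed QKV, "bs3hd"
kv_bs2hd = torch.empty(b, s, 2, h, d)   # packed KV, as in "bshd_bs2hd"
```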
@@ -968,19 +777,14 @@ def fused_attn_fwd(
        [seed, offset], dtype uint64
    """
-    check_cu_seqlens(cu_seqlens_q)
-    check_cu_seqlens(cu_seqlens_kv)
-    assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
-        ), "cu_seqlens_q and cu_seqlens_kv must have the same length."
-    h = q.shape[-2]
-    d = q.shape[-1]
    if attn_scale is None:
+        d = q.size(-1)
        attn_scale = 1.0 / math.sqrt(d)
-    if attn_bias_type != "no_bias":
+    if attn_bias_type not in ["no_bias", "alibi"]:
        assert (attn_bias is not None
-            ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
+            ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
+        h = q.size(2)
        assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
            ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (attn_bias.dtype == q.dtype
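A minimal sketch of the two contracts checked above: the default softmax scale and the expected bias shape. All sizes are hypothetical.

```python
import math
import torch

b, h, max_seqlen_q, max_seqlen_kv, d = 2, 16, 128, 128, 64  # hypothetical sizes
q = torch.empty(b, max_seqlen_q, h, d, dtype=torch.bfloat16)

attn_scale = 1.0 / math.sqrt(q.size(-1))      # default when attn_scale is None

# a "post_scale_bias" must be [1, h, max_seqlen_q, max_seqlen_kv], same dtype as Q
attn_bias = torch.zeros(1, h, max_seqlen_q, max_seqlen_kv, dtype=q.dtype)
assert attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
```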
@@ -1061,23 +865,11 @@ def fused_attn_bwd(
    cu_seqlens_kv: torch.Tensor
        cumulative sequence lengths for K and V; shape [batch_size + 1]
    q: torch.Tensor
-        input tensor Q;
-        shape [total_seqs_q, num_heads, head_dim],
-        where total_seqs_q = cu_seqlens_q[-1],
-        or [batch_size, seqlen_q, num_heads, head_dim],
-        or [seqlen_q, batch_size, num_heads, head_dim]
+        input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
    k: torch.Tensor
-        input tensor K;
-        shape [total_seqs_kv, num_heads, head_dim],
-        where total_seqs_kv = cu_seqlens_kv[-1],
-        or [batch_size, seqlen_kv, num_heads, head_dim],
-        or [seqlen_kv, batch_size, num_heads, head_dim]
+        input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
    v: torch.Tensor
-        input tensor V;
-        shape [total_seqs_kv, num_heads, head_dim],
-        where total_seqs_kv = cu_seqlens_kv[-1],
-        or [batch_size, seqlen_kv, num_heads, head_dim],
-        or [seqlen_kv, batch_size, num_heads, head_dim]
+        input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
    o: torch.Tensor
        input tensor O (output of forward); same data type as Q, K and V;
        same shape as Q
@@ -1125,9 +917,9 @@ def fused_attn_bwd(
        "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
        "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
    attn_bias_type: str, default = "no_bias"
-        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
+        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
    attn_mask_type: str, default = "padding"
-        type of the attention mask; {"padding", "causal", "no_mask"}
+        type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}

    Returns
    ----------
@@ -1142,15 +934,8 @@ def fused_attn_bwd(
        or "post_scale_bias"; same data type and shape as Bias
    """
-    check_cu_seqlens(cu_seqlens_q)
-    check_cu_seqlens(cu_seqlens_kv)
-    assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
-        ), "cu_seqlens_q and cu_seqlens_kv must have the same length."
-    b = cu_seqlens_q.numel() - 1
-    h = q.shape[-2]
-    d = q.shape[-1]
    if attn_scale is None:
+        d = q.size(-1)
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
@@ -1159,8 +944,6 @@ def fused_attn_bwd(
    if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
        assert (len(aux_ctx_tensors) >= 1
            ), "aux_ctx_tensors must contain rng_state as its last element."
-        rng_state = aux_ctx_tensors[-1]
-        check_rng_state(rng_state)
    if fused_attention_backend == FusedAttnBackend["FP8"]:
        assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
@@ -1174,18 +957,6 @@ def fused_attn_bwd(
        assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
        assert (len(aux_ctx_tensors) == 3
            ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
-        check_scalar(d_scale_qkv)
-        check_scalar(d_scale_s)
-        check_scalar(d_scale_o)
-        check_scalar(d_scale_do)
-        check_scalar(q_scale_s)
-        check_scalar(q_scale_dp)
-        check_scalar(q_scale_dqkv)
-        check_scalar(amax_dp)
-        check_scalar(amax_dqkv)
-        m, z_inv = aux_ctx_tensors[:2]
-        check_stats(m, b, h, max_seqlen_q)
-        check_stats(z_inv, b, h, max_seqlen_q)
    # execute kernel
    output_tensors = tex.fused_attn_bwd(
@@ -1196,4 +967,4 @@ def fused_attn_bwd(
        q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
    )
-    return tuple(output_tensors)
+    return output_tensors
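The asserts above amount to a small contract on `aux_ctx_tensors` as it is carried from forward to backward. A hedged restatement in plain Python, with only the backend names and lengths taken from the asserts; the helper itself is illustrative:

```python
def validate_aux_ctx_tensors(aux_ctx_tensors, backend_name):
    # All backends except F16_max512_seqlen carry rng_state as the last element.
    if backend_name != "F16_max512_seqlen":
        assert len(aux_ctx_tensors) >= 1, "rng_state must be the last element"
    # The FP8 backend is stricter: exactly [M, ZInv, rng_state].
    if backend_name == "FP8":
        assert len(aux_ctx_tensors) == 3, "expected [M, ZInv, rng_state]"
```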
@@ -13,12 +13,13 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type,
-    float p_dropout, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim);
+    float p_dropout,
+    size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv,
+    size_t head_dim);

std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
-    size_t b, size_t max_seqlen, size_t total_seqs,
-    size_t h, size_t d, bool is_training,
+    size_t max_seqlen, bool is_training,
    float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
@@ -36,8 +37,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
-    size_t b, size_t max_seqlen, size_t total_seqs,
-    size_t h, size_t d, float attn_scale,
+    size_t max_seqlen, float attn_scale,
    float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
@@ -59,9 +59,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
    c10::optional<at::Tensor> amax_dQKV);

std::vector<at::Tensor> fused_attn_fwd_kvpacked(
-    size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t total_seqs_q, size_t total_seqs_kv,
-    size_t h, size_t d, bool is_training,
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
    float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
@@ -81,10 +79,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd_kvpacked(
-    size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t total_seqs_q, size_t total_seqs_kv,
-    size_t h, size_t d, float attn_scale,
-    float p_dropout, bool set_zero,
+    size_t max_seqlen_q, size_t max_seqlen_kv,
+    float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type,
...
@@ -16,13 +16,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
    NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type,
-    float p_dropout, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim) {
+    float p_dropout,
+    size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv,
+    size_t head_dim) {
  NVTE_Fused_Attn_Backend fused_attention_backend =
      nvte_get_fused_attn_backend(
          static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
-          qkv_layout, bias_type, attn_mask_type,
-          p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
+          qkv_layout, bias_type, attn_mask_type, p_dropout,
+          num_attn_heads, num_gqa_groups,
+          max_seqlen_q, max_seqlen_kv, head_dim);
  return fused_attention_backend;
}
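Assuming this function is exposed to Python under the usual extension module (the import name, enum members, and all concrete sizes below are illustrative, not confirmed by this diff), backend selection could be probed like this; note the two new arguments `num_attn_heads` and `num_gqa_groups`:

```python
import transformer_engine_extensions as tex  # hypothetical import name

backend = tex.get_fused_attn_backend(
    tex.DType.kBFloat16, tex.DType.kBFloat16,     # q_dtype, kv_dtype
    tex.NVTE_QKV_Layout.NVTE_SB3HD,               # qkv_layout
    tex.NVTE_Bias_Type.NVTE_NO_BIAS,              # bias_type
    tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,          # attn_mask_type
    0.0,                                          # p_dropout
    16, 16,                                       # num_attn_heads, num_gqa_groups
    512, 512,                                     # max_seqlen_q, max_seqlen_kv
    64,                                           # head_dim
)
```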
@@ -87,9 +90,8 @@ at::PhiloxCudaState init_philox_state(
// fused attention FWD with packed QKV
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
-    size_t b, size_t max_seqlen, size_t total_seqs,
-    size_t h, size_t d,
-    bool is_training, float attn_scale, float p_dropout, bool set_zero,
+    size_t max_seqlen, bool is_training, float attn_scale,
+    float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    const at::Tensor cu_seqlens,
    const at::Tensor QKV,
@@ -104,16 +106,27 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    size_t rng_elts_per_thread) {
  using namespace transformer_engine;

+  auto qkv_sizes = QKV.sizes().vec();
+  std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
+  std::vector<size_t> q_shape;
+  for (auto i : qkv_shape) {
+    if (i != 3) {
+      q_shape.push_back(i);
+    }
+  }
+  std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};

  // create output tensor O
  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
-  auto O = torch::empty({static_cast<int64_t>(total_seqs),
-      static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
+  auto O = torch::empty(o_shape, options);

  // construct NVTE tensors
  TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    if (set_zero && (h * d % block_size == 0)) {
+    auto h = q_shape[q_shape.size() - 2];
+    auto d = q_shape[q_shape.size() - 1];
+    if (set_zero && ((h * d) % block_size == 0)) {
      mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    } else {
      O.fill_(0);
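The `q_shape` loop above recovers the per-tensor shape by dropping the packing axis (the dimension of size 3) from the packed QKV shape. A Python sketch of the same idea, with the caveat that it implicitly assumes no other axis happens to equal the group size:

```python
def unpacked_shape(qkv_shape, group=3):
    # e.g. bs3hd: [b, s, 3, h, d] -> [b, s, h, d]; assumes no other dim == group
    return [dim for dim in qkv_shape if dim != group]

assert unpacked_shape([2, 128, 3, 16, 64]) == [2, 128, 16, 64]  # bs3hd
assert unpacked_shape([256, 3, 16, 64]) == [256, 16, 64]        # t3hd
```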
@@ -123,32 +136,34 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
    NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
  }
-  te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
+  te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
      qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
  at::Tensor descale_S = torch::empty_like(scale_S.value());
  te_S = makeTransformerEngineTensor(nullptr, {0},
      DType::kFloat32, amax_S.value().data_ptr(),
      scale_S.value().data_ptr(), descale_S.data_ptr());
-  te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
+  te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
      qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
  // BF16 or FP16
-  te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
+  te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
      qkv_type, nullptr, nullptr, nullptr);
  te_S = makeTransformerEngineTensor(nullptr, {0},
      DType::kFloat32, nullptr, nullptr, nullptr);
-  te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
+  te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
      qkv_type, nullptr, nullptr, nullptr);
} else {
  NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
-  if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
-    auto bias_shape = Bias.value().sizes().vec();
-    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
-    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+    auto bias_sizes = Bias.value().sizes().vec();
+    std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
+    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
      DType::kFloat32, nullptr, nullptr, nullptr);
  }
-  te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
+  auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
+  std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
+  te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
      DType::kInt32, nullptr, nullptr, nullptr);

// extract random number generator seed and offset
@@ -196,8 +211,18 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    // allocate memory for nvte_aux_tensor_pack.tensors
    at::Tensor output_tensor;
    if (nvte_aux_tensor_pack.size >= 2) {
-      output_tensor = (i < nvte_aux_tensor_pack.size-1)
-          ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+        if (i < nvte_aux_tensor_pack.size - 2) {
+          output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
+        } else if (i == nvte_aux_tensor_pack.size - 2) {
+          output_tensor = rng_state;
+        } else if (i == nvte_aux_tensor_pack.size - 1) {
+          output_tensor = Bias.value();
+        }
+      } else {
+        output_tensor = (i < nvte_aux_tensor_pack.size-1)
+            ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      }
    } else {
      output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
    }
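The branch above fixes the layout of the auxiliary output pack: softmax stats first, then `rng_state`, and, when an explicit bias tensor was passed, the bias as the final entry. A hedged sketch of the resulting slot assignment (names are illustrative):

```python
def aux_pack_layout(n, with_bias):
    # n = nvte_aux_tensor_pack.size; returns what each slot i holds
    if n < 2:
        return ["allocated"] * n
    if with_bias:
        return ["allocated"] * (n - 2) + ["rng_state", "bias"]
    return ["allocated"] * (n - 1) + ["rng_state"]

assert aux_pack_layout(2, False) == ["allocated", "rng_state"]
assert aux_pack_layout(3, True) == ["allocated", "rng_state", "bias"]
```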
@@ -229,9 +254,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// fused attention BWD with packed QKV
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
-    size_t b, size_t max_seqlen, size_t total_seqs,
-    size_t h, size_t d,
-    float attn_scale, float p_dropout, bool set_zero,
+    size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    const at::Tensor cu_seqlens,
    const at::Tensor QKV,
@@ -250,12 +273,22 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
    c10::optional<at::Tensor> amax_dQKV) {
  using namespace transformer_engine;

+  auto qkv_sizes = QKV.sizes().vec();
+  std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
+  std::vector<size_t> q_shape;
+  for (auto i : qkv_shape) {
+    if (i != 3) {
+      q_shape.push_back(i);
+    }
+  }
+  auto h = q_shape[q_shape.size() - 2];

  // create output tensor dQKV
  at::Tensor dQKV = torch::empty_like(QKV);
  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
  at::Tensor dBias;
  TensorWrapper te_dBias;
-  if (bias_type != NVTE_NO_BIAS) {
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
    dBias = torch::empty({1, static_cast<int64_t>(h),
        static_cast<int64_t>(max_seqlen),
        static_cast<int64_t>(max_seqlen)}, options);
@@ -266,10 +299,8 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
  TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    auto max_tokens = dQKV.size(0);
-    auto self_2d = dQKV.view({max_tokens, -1});
-    auto fcd_size = self_2d.size(1);
-    if (set_zero && (fcd_size % block_size == 0)) {
+    auto d = q_shape[q_shape.size() - 1];
+    if (set_zero && ((h * d) % block_size == 0)) {
      mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    } else {
      dQKV.fill_(0);
@@ -283,35 +314,33 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
      err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
      NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
    }
-    te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
+    te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
        qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
-    te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
+    te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
-    te_S = makeTransformerEngineTensor(nullptr, {0},
-        DType::kFloat32,
+    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
        nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
    at::Tensor descale_dP = torch::empty_like(scale_dP.value());
    te_dP = makeTransformerEngineTensor(nullptr, {0},
        DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
        descale_dP.data_ptr());
-    te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
-        qkv_type,
+    te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, qkv_type,
        amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
  } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
    // BF16 or FP16
-    te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d},
+    te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d},
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d},
+    te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
    te_S = makeTransformerEngineTensor(nullptr, {0},
        DType::kFloat32, nullptr, nullptr, nullptr);
    te_dP = makeTransformerEngineTensor(nullptr, {0},
        DType::kFloat32, nullptr, nullptr, nullptr);
-    te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d},
+    te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape,
        qkv_type, nullptr, nullptr, nullptr);
  } else {
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
@@ -330,8 +359,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
  }

  // create cu_seqlens tensorwrappers
-  TensorWrapper te_cu_seqlens;
-  te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1},
+  auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
+  std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
+  TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
      DType::kInt32, nullptr, nullptr, nullptr);

  // create workspace
@@ -385,9 +415,7 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// fused attention FWD with packed KV
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
-    size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t total_seqs_q, size_t total_seqs_kv,
-    size_t h, size_t d,
+    size_t max_seqlen_q, size_t max_seqlen_kv,
    bool is_training, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    const at::Tensor cu_seqlens_q,
@@ -405,16 +433,23 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    size_t rng_elts_per_thread) {
  using namespace transformer_engine;

+  auto q_sizes = Q.sizes().vec();
+  std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
+  auto kv_sizes = KV.sizes().vec();
+  std::vector<size_t> kv_shape{kv_sizes.begin(), kv_sizes.end()};
+  std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};

  // create output tensor O
  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
-  auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
-      static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
+  auto O = torch::empty(o_shape, options);

  // construct NVTE tensors
  TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    if (set_zero && (h * d % block_size == 0)) {
+    auto h = q_shape[q_shape.size() - 2];
+    auto d = q_shape[q_shape.size() - 1];
+    if (set_zero && ((h * d) % block_size == 0)) {
      mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    } else {
      O.fill_(0);
@@ -424,38 +459,42 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
    NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
  }
-  te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
+  te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
      qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
-  te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
+  te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
      qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
  at::Tensor descale_S = torch::empty_like(scale_S.value());
  te_S = makeTransformerEngineTensor(nullptr, {0},
      DType::kFloat32, amax_S.value().data_ptr(),
      scale_S.value().data_ptr(), descale_S.data_ptr());
-  te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
+  te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
      qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
  // BF16 or FP16
-  te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
+  te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
      qkv_type, nullptr, nullptr, nullptr);
-  te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
+  te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
      qkv_type, nullptr, nullptr, nullptr);
  te_S = makeTransformerEngineTensor(nullptr, {0},
      DType::kFloat32, nullptr, nullptr, nullptr);
-  te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
+  te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
      qkv_type, nullptr, nullptr, nullptr);
} else {
  NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
-  if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
-    auto bias_shape = Bias.value().sizes().vec();
-    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
-    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape,
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+    auto bias_sizes = Bias.value().sizes().vec();
+    std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
+    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
      DType::kFloat32, nullptr, nullptr, nullptr);
  }
-  te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
+  auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
+  std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
+  auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
+  std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
+  te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
      DType::kInt32, nullptr, nullptr, nullptr);
-  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
+  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
      DType::kInt32, nullptr, nullptr, nullptr);

  // extract rng seed and offset
@@ -505,8 +544,18 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    // allocate memory for nvte_aux_tensor_pack.tensors
    at::Tensor output_tensor;
    if (nvte_aux_tensor_pack.size >= 2) {
-      output_tensor = (i < nvte_aux_tensor_pack.size-1)
-          ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+        if (i < nvte_aux_tensor_pack.size - 2) {
+          output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
+        } else if (i == nvte_aux_tensor_pack.size - 2) {
+          output_tensor = rng_state;
+        } else if (i == nvte_aux_tensor_pack.size - 1) {
+          output_tensor = Bias.value();
+        }
+      } else {
+        output_tensor = (i < nvte_aux_tensor_pack.size-1)
+            ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      }
    } else {
      output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
    }
// fused attention BWD with packed KV // fused attention BWD with packed KV
std::vector<at::Tensor> fused_attn_bwd_kvpacked( std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_q,
...@@ -564,14 +611,28 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -564,14 +611,28 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
c10::optional<at::Tensor> amax_dQKV) { c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine; using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
auto kv_sizes = KV.sizes().vec();
std::vector<size_t> kv_shape{kv_sizes.begin(), kv_sizes.end()};
std::vector<size_t> k_shape;
for (auto i : kv_shape) {
if (i != 2) {
k_shape.push_back(i);
}
}
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
// create output tensors dQ and dKV // create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q); at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV); at::Tensor dKV = torch::empty_like(KV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias; at::Tensor dBias;
TensorWrapper te_dBias; TensorWrapper te_dBias;
if (bias_type != NVTE_NO_BIAS) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h), dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q), static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options); static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias); te_dBias = makeTransformerEngineTensor(dBias);
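`h_q` and `h_kv` are read off the second-to-last axes of the Q and unpacked KV shapes, which is what lets K/V carry fewer heads than Q under grouped-query attention. An illustrative Python sketch (sizes are hypothetical; like the C++ loop, the filter assumes no other axis equals the packing size):

```python
import torch

b, s, h_q, h_kv, d = 4, 128, 16, 4, 64        # hypothetical GQA setup
q = torch.empty(b, s, h_q, d)
kv = torch.empty(b, s, 2, h_kv, d)            # packed KV, as in "bshd_bs2hd"

k_shape = [dim for dim in kv.shape if dim != 2]   # drop the packing axis
assert q.shape[-2] == h_q and k_shape[-2] == h_kv
num_gqa_groups = k_shape[-2]   # the KV head count that get_fused_attn_backend now receives
```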
@@ -581,13 +642,7 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
  TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    auto max_tokens_q = dQ.size(0);
-    auto self_2d_q = dQ.view({max_tokens_q, -1});
-    auto fcd_size_q = self_2d_q.size(1);
-    auto max_tokens_kv = dQ.size(0);
-    auto self_2d_kv = dQ.view({max_tokens_kv, -1});
-    auto fcd_size_kv = self_2d_kv.size(1);
-    if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
+    if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0)) {
      mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
      mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    } else {
@@ -603,13 +658,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
      err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
      NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
    }
-    te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
+    te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
-    te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
+    te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
        qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
-    te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
+    te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
        scale_S.value().data_ptr(), descale_S.value().data_ptr());
@@ -617,37 +672,41 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
    te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
        amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
        descale_dP.data_ptr());
-    te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type,
+    te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type,
        amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
-    te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type,
+    te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, qkv_type,
        amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
  } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
    // BF16 or FP16
-    te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d},
+    te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d},
+    te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d},
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d},
+    te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
    te_S = makeTransformerEngineTensor(nullptr, {0},
        DType::kFloat32, nullptr, nullptr, nullptr);
    te_dP = makeTransformerEngineTensor(nullptr, {0},
        DType::kFloat32, nullptr, nullptr, nullptr);
-    te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d},
+    te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
        qkv_type, nullptr, nullptr, nullptr);
-    te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d},
+    te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape,
        qkv_type, nullptr, nullptr, nullptr);
  } else {
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
  }

  // create cu_seqlens tensorwrappers
+  auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
+  std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
+  auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
+  std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
-  te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1},
+  te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
      DType::kInt32, nullptr, nullptr, nullptr);
-  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1},
+  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
      DType::kInt32, nullptr, nullptr, nullptr);

  // convert auxiliary tensors from forward to NVTETensors
@@ -753,8 +812,8 @@ std::vector<at::Tensor> fused_attn_fwd(
  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    auto h = Q.size(-2);
-    auto d = Q.size(-1);
+    auto h = q_shape[q_shape.size() - 2];
+    auto d = q_shape[q_shape.size() - 1];
    if (set_zero && ((h * d) % block_size == 0)) {
      mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
    } else {
@@ -792,7 +851,7 @@ std::vector<at::Tensor> fused_attn_fwd(
  } else {
    NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
  }
-  if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
    auto bias_sizes = Bias.value().sizes().vec();
    std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
    te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
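`set_zero` picks between `mha_fill`, which only touches the token rows past the end of the packed batch, and a full `fill_(0)` fallback when `h * d` does not tile into `block_size`. Conceptually (thd-style packing, hypothetical sizes; this is a sketch of the semantics, not the kernel itself):

```python
import torch

total_tokens, h, d = 256, 16, 64
cu_seqlens = torch.tensor([0, 100, 180])   # 2 sequences, 180 valid tokens
o = torch.empty(total_tokens, h, d)

valid = cu_seqlens[-1].item()
o[valid:].zero_()   # fast path: only the unused tail needs zeroing
# fallback: o.zero_() clears the whole buffer
```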
@@ -856,8 +915,18 @@ std::vector<at::Tensor> fused_attn_fwd(
    // allocate memory for nvte_aux_tensor_pack.tensors
    at::Tensor output_tensor;
    if (nvte_aux_tensor_pack.size >= 2) {
-      output_tensor = (i < nvte_aux_tensor_pack.size-1)
-          ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+        if (i < nvte_aux_tensor_pack.size - 2) {
+          output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
+        } else if (i == nvte_aux_tensor_pack.size - 2) {
+          output_tensor = rng_state;
+        } else if (i == nvte_aux_tensor_pack.size - 1) {
+          output_tensor = Bias.value();
+        }
+      } else {
+        output_tensor = (i < nvte_aux_tensor_pack.size-1)
+            ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
+      }
    } else {
      output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
    }
@@ -988,7 +1057,7 @@ std::vector<at::Tensor> fused_attn_bwd(
  at::Tensor dBias;
  TensorWrapper te_dBias;
-  if (bias_type != NVTE_NO_BIAS) {
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
    dBias = torch::empty({1, static_cast<int64_t>(Q.size(-2)),
        static_cast<int64_t>(max_seqlen_q),
        static_cast<int64_t>(max_seqlen_kv)}, options);
@@ -999,9 +1068,9 @@ std::vector<at::Tensor> fused_attn_bwd(
  TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
-    auto h_q = Q.size(-2);
-    auto h_kv = K.size(-2);
-    auto d = Q.size(-1);
+    auto h_q = q_shape[q_shape.size() - 2];
+    auto h_kv = k_shape[k_shape.size() - 2];
+    auto d = q_shape[q_shape.size() - 1];
    if (set_zero
        && ((h_q * d) % block_size == 0)
        && ((h_kv * d) % block_size == 0)
@@ -1078,7 +1147,7 @@ std::vector<at::Tensor> fused_attn_bwd(
    std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
    auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
    std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
-  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv, te_qkvso_strides;
+  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
  te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
      DType::kInt32, nullptr, nullptr, nullptr);
  te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
      DType::kInt32, nullptr, nullptr, nullptr);
...
@@ -149,17 +149,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
-      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
+      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)
+      .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI);
  py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
-      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
+      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
+      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
  py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
-      .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
-      .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
-      .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
      .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
      .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
      .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
...
@@ -228,7 +228,7 @@ class _NoopCat(torch.autograd.Function):
        ), "Dimensions not compatible for concatenation"

        param_temp = full_param_buffer.new()
-        param_temp.set_(full_param_buffer.storage(),
+        param_temp.set_(full_param_buffer.untyped_storage(),
            full_param_buffer.storage_offset(),
            full_param_buffer.size(),
            full_param_buffer.stride())
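`Tensor.storage()` is deprecated in recent PyTorch in favor of `untyped_storage()`; the aliasing trick itself is unchanged. A standalone sketch of what this pattern does:

```python
import torch

full_param_buffer = torch.arange(12, dtype=torch.float32)

param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.untyped_storage(),
                full_param_buffer.storage_offset(),
                full_param_buffer.size(),
                full_param_buffer.stride())

# param_temp aliases the same memory: writes are visible through both views
param_temp[0] = -1.0
assert full_param_buffer[0].item() == -1.0
```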
...
@@ -70,8 +70,8 @@ class TransformerLayer(torch.nn.Module):
    .. note::

-        Argument :attr:`attention_mask` will be ignored in the `forward` call when
-        :attr:`self_attn_mask_type` is set to `"causal"`.
+        Argument :attr:`attention_mask` in the `forward` call is only used when
+        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
@@ -127,7 +127,8 @@ class TransformerLayer(torch.nn.Module):
    kv_channels: int, default = `None`
        number of key-value channels. defaults to
        :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
-    self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
+    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
+        default = `causal`
        type of attention mask passed into softmax operation. Overridden by
        :attr:`self_attn_mask_type` in the `forward` method. The forward
        arg is useful for dynamically changing mask types, e.g. a different
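A minimal usage sketch of the two places a mask type can be set, the constructor default and a per-call override. The layer hyperparameters and sizes are made up, and a CUDA build of Transformer Engine is assumed:

```python
import torch
import transformer_engine.pytorch as te   # assuming the usual import path

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
    self_attn_mask_type="causal",          # constructor-level default
).cuda()
x = torch.randn(128, 2, 1024, device="cuda")   # [seqlen, batch, hidden]

# per-call override: padding mask in [batch_size, 1, 1, seqlen_q], True = masked out
padding_mask = torch.zeros(2, 1, 1, 128, dtype=torch.bool, device="cuda")
y = layer(x, attention_mask=padding_mask, self_attn_mask_type="padding")
```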
@@ -491,7 +492,7 @@ class TransformerLayer(torch.nn.Module):
        attention_mask: Optional[torch.Tensor] = None,
        self_attn_mask_type: Optional[str] = None,
        encoder_output: Optional[torch.Tensor] = None,
-        enc_dec_attn_mask: Optional[torch.Tensor] = None,
+        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
@@ -512,17 +513,23 @@ class TransformerLayer(torch.nn.Module):
        ----------
        hidden_states : torch.Tensor
            Input tensor.
-        attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+        attention_mask : Optional[torch.Tensor], default = `None`
            Boolean tensor used to mask out self-attention softmax input.
-            Can be a tuple of 2 masks for cross attention with padding masks.
-        self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
-            type of attention mask passed into softmax operation.
+            It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
+            and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
+            for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
+        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
+            default = `causal`
+            Type of attention mask passed into softmax operation.
        encoder_output : Optional[torch.Tensor], default = `None`
            Output of the encoder block to be fed into the decoder block if using
            `layer_type="decoder"`.
-        enc_dec_attn_mask : Optional[torch.Tensor], default = `None`
-            Boolean tensor used to mask out inter-attention softmax input if using
-            `layer_type="decoder"`.
+        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+            default = `None`. Boolean tensors used to mask out inter-attention softmax input if
+            using `layer_type="decoder"`. It should be a tuple of two masks in
+            [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
+            It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
+            for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
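For a decoder layer, the padding variant of `enc_dec_attn_mask` is a pair of per-side masks, as the docstring above now spells out. Illustrative shapes only:

```python
import torch

batch, seqlen_q, seqlen_kv = 2, 128, 256   # hypothetical decoder/encoder lengths

# True marks positions to mask out, one mask per side of the cross attention
q_pad_mask = torch.zeros(batch, 1, 1, seqlen_q, dtype=torch.bool, device="cuda")
kv_pad_mask = torch.zeros(batch, 1, 1, seqlen_kv, dtype=torch.bool, device="cuda")
enc_dec_attn_mask = (q_pad_mask, kv_pad_mask)
```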
@@ -545,7 +552,7 @@ class TransformerLayer(torch.nn.Module):
            Embeddings for query and key tensors for applying rotary position
            embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
-            Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`}
+            Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
            Bias tensor for Q * K.T
        fast_zero_fill: bool, default = `True`
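The `alibi` bias type needs no user-supplied bias tensor (see the `no_bias`/`alibi` checks earlier in this diff); the bias is a fixed linear penalty on query/key distance. One common construction is shown below for reference, not as the kernel's exact recipe:

```python
import torch

def alibi_bias(num_heads, seqlen_q, seqlen_kv):
    # per-head slopes 2^(-8i/n), the geometric schedule from the ALiBi paper
    slopes = torch.tensor(
        [2 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
    pos_q = torch.arange(seqlen_q).view(-1, 1)
    pos_k = torch.arange(seqlen_kv).view(1, -1)
    dist = (pos_k - pos_q).to(torch.float32)      # relative distance k - q
    return slopes.view(num_heads, 1, 1) * dist    # [h, seqlen_q, seqlen_kv]

bias = alibi_bias(16, 128, 128)   # added to Q*K^T scores before softmax
```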
@@ -569,7 +576,9 @@ class TransformerLayer(torch.nn.Module):
            hidden_states.shape[0] == self.seq_length // self.tp_size
        ), "Sequence dimension must be split across TP group when using sequence parallel."

-        if self_attn_mask_type != "causal" and attention_mask is not None:
+        if (("padding" in self_attn_mask_type
+                or self_attn_mask_type == "arbitrary")
+                and attention_mask is not None):
            assert (
                attention_mask.dtype == torch.bool
            ), "Attention mask must be a boolean tensor"
...