Unverified commit d7704b98, authored by cyanguwa and committed by GitHub

Fix header files for doxygen (#252)



* fix headers for doxygen
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix description f16 and use half precision instead
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f70b4bbf
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/
-/*! \file fused_attn_max_512.h
- *  \brief Functions for fused attention with seqlen <= 512
+/*! \file fused_attn_fp16_bf16_max_seqlen_512.h
+ *  \brief Functions for fused attention for half precision with seqlen <= 512
  */
 #ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
...
@@ -4,6 +4,10 @@
  * See LICENSE for license information.
  ************************************************************************/
+/*! \file fused_attn_fp8.h
+ *  \brief Functions for fused attention for FP8 with seqlen <= 512
+ */
 #include "transformer_engine/transformer_engine.h"
 namespace transformer_engine {
...
@@ -4,6 +4,10 @@
  * See LICENSE for license information.
  ************************************************************************/
+/*! \file fused_attn.h
+ *  \brief Enums and functions for fused attention.
+ */
 #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
 #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
@@ -13,75 +17,97 @@
 extern "C" {
 #endif
 
+/*! \enum NVTE_QKV_Layout
+ *  \brief QKV matrix layouts
+ */
 enum NVTE_QKV_Layout {
-  /*!< separate Q, K, V tensors:
-       Q: [total_seqs_q, num_heads, head_dim]
-                       | Q Q Q ... Q
-                       | \___________ _____________/
-       total_seqs_q   <|            \/
-                       |   num_heads * head_dim
-       K: [total_seqs_kv, num_heads, head_dim]
-                       | K K K ... K
-                       | \___________ _____________/
-       total_seqs_kv  <|            \/
-                       |   num_heads * head_dim
-       V: [total_seqs_kv, num_heads, head_dim]
-                       | V V V ... V
-                       | \___________ _____________/
-       total_seqs_kv  <|            \/
-                       |   num_heads * head_dim
-   */
+  /*! Separate Q, K, V tensors.
+   \verbatim
+       Q: [total_seqs_q, num_heads, head_dim]
+                       | Q Q Q ... Q
+                       | \___________ _____________/
+       total_seqs_q   <|            \/
+                       |   num_heads * head_dim
+       K: [total_seqs_kv, num_heads, head_dim]
+                       | K K K ... K
+                       | \___________ _____________/
+       total_seqs_kv  <|            \/
+                       |   num_heads * head_dim
+       V: [total_seqs_kv, num_heads, head_dim]
+                       | V V V ... V
+                       | \___________ _____________/
+       total_seqs_kv  <|            \/
+                       |   num_heads * head_dim
+   \endverbatim
+   */
   NVTE_NOT_INTERLEAVED = 0,
-  /*!< packed QKV tensor:
-       QKV: [total_seqs, 3, num_heads, head_dim]
-                       | Q Q Q ... Q K K K ... K V V V ... V
-                       | \___________ _____________/
-       total_seqs     <|            \/
-                       |   num_heads * head_dim
-   */
+  /*! Packed QKV.
+   \verbatim
+       QKV: [total_seqs, 3, num_heads, head_dim]
+                       | Q Q Q ... Q K K K ... K V V V ... V
+                       | \___________ _____________/
+       total_seqs     <|            \/
+                       |   num_heads * head_dim
+   \endverbatim
+   */
   NVTE_QKV_INTERLEAVED = 1,
-  /*!< Q and packed KV tensor:
-       Q: [total_seqs_q, num_heads, head_dim]
-                       | Q Q Q ... Q
-                       | \___________ _____________/
-       total_seqs_q   <|            \/
-                       |   num_heads * head_dim
-       KV: [total_seqs_kv, 2, num_heads, head_dim]
-                       | K K K ... K V V V ... V
-                       | \___________ _____________/
-       total_seqs_kv  <|            \/
-                       |   num_heads * head_dim
-   */
+  /*! Q and packed KV.
+   \verbatim
+       Q: [total_seqs_q, num_heads, head_dim]
+                       | Q Q Q ... Q
+                       | \___________ _____________/
+       total_seqs_q   <|            \/
+                       |   num_heads * head_dim
+       KV: [total_seqs_kv, 2, num_heads, head_dim]
+                       | K K K ... K V V V ... V
+                       | \___________ _____________/
+       total_seqs_kv  <|            \/
+                       |   num_heads * head_dim
+   \endverbatim
+   */
   NVTE_KV_INTERLEAVED = 2
 };
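To make the three layouts above concrete, here is a minimal index-arithmetic sketch. The helper names are hypothetical; only the shapes documented in the comments are assumed, with row-major ordering.

```cpp
#include <cstddef>

// NVTE_QKV_INTERLEAVED: QKV[total_seqs][3][num_heads][head_dim],
// where matrix 0 is Q, 1 is K, 2 is V.
std::size_t qkv_interleaved_offset(std::size_t t, std::size_t matrix,
                                   std::size_t h, std::size_t c,
                                   std::size_t num_heads, std::size_t head_dim) {
  return ((t * 3 + matrix) * num_heads + h) * head_dim + c;
}

// NVTE_KV_INTERLEAVED: KV[total_seqs_kv][2][num_heads][head_dim],
// where matrix 0 is K and 1 is V; Q stays in its own contiguous buffer.
std::size_t kv_interleaved_offset(std::size_t t, std::size_t matrix,
                                  std::size_t h, std::size_t c,
                                  std::size_t num_heads, std::size_t head_dim) {
  return ((t * 2 + matrix) * num_heads + h) * head_dim + c;
}

// NVTE_NOT_INTERLEAVED: each of Q, K, V is [total_seqs][num_heads][head_dim].
std::size_t not_interleaved_offset(std::size_t t, std::size_t h, std::size_t c,
                                   std::size_t num_heads, std::size_t head_dim) {
  return (t * num_heads + h) * head_dim + c;
}
```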
 
+/*! \enum NVTE_Bias_Type
+ *  \brief Bias types
+ */
 enum NVTE_Bias_Type {
-  NVTE_NO_BIAS = 0,         /*!< no bias */
-  NVTE_PRE_SCALE_BIAS = 1,  /*!< bias before scale */
-  NVTE_POST_SCALE_BIAS = 2  /*!< bias after scale */
+  /*! No bias */
+  NVTE_NO_BIAS = 0,
+  /*! Bias before scale */
+  NVTE_PRE_SCALE_BIAS = 1,
+  /*! Bias after scale */
+  NVTE_POST_SCALE_BIAS = 2
 };
 
+/*! \enum NVTE_Mask_Type
+ *  \brief Attention mask types
+ */
 enum NVTE_Mask_Type {
-  NVTE_PADDING_MASK = 0,  /*!< padding attention mask */
-  NVTE_CAUSAL_MASK = 1,   /*!< causal attention mask */
-  NVTE_NO_MASK = 2        /*!< no masking */
+  /*! No masking */
+  NVTE_NO_MASK = 0,
+  /*! Padding attention mask */
+  NVTE_PADDING_MASK = 1,
+  /*! Causal attention mask */
+  NVTE_CAUSAL_MASK = 2,
 };
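A conceptual sketch of what the mask types mean for a score at query position i and key position j. This helper is illustrative only (it assumes the NVTE_Mask_Type enum above), not part of the API.

```cpp
// Whether attention score (i, j) is masked out (set to -inf before softmax).
// seqlen is the valid length of the current sequence.
bool is_masked(NVTE_Mask_Type mask, int i, int j, int seqlen) {
  switch (mask) {
    case NVTE_NO_MASK:      return false;
    case NVTE_PADDING_MASK: return j >= seqlen;  // keys beyond the valid tokens
    case NVTE_CAUSAL_MASK:  return j > i;        // keys in the "future" of query i
  }
  return false;
}
```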
 
 /*! \brief Compute dot product attention with packed QKV input.
  *
  * Computes:
- *  - P = Q * K.T + Bias
+ *  - P = Q * Transpose(K) + Bias
  *  - S = ScaleMaskSoftmax(P)
  *  - D = Dropout(S)
- *  - O = D * V.T
+ *  - O = D * Transpose(V)
  *
  * Support Matrix:
- * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
- * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \verbatim
+   | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+   | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
+   | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \endverbatim
  *
  * \param[in]     QKV                  The QKV tensor in packed format,
  *                                     [total_seqs, 3, num_heads, head_dim].
@@ -91,8 +117,8 @@ enum NVTE_Mask_Type {
  * \param[out]    Aux_Output_Tensors   Auxiliary output tensors when training, e.g. M, ZInv.
  * \param[in]     cu_seqlens           Accumulative sequence lengths, [batch_size + 1].
  * \param[in]     rng_state            Seed and offset of CUDA random number generator.
- * \param[in]     max_seqlen           Max sequence length used for computing,
- *                                     it may be >= max(cu_seqlens).
+ * \param[in]     max_seqlen           Max sequence length used for computing.
+ *                                     It may be >= max(cu_seqlens).
  * \param[in]     is_training          Whether this is in training mode or inference.
  * \param[in]     attn_scale           Scaling factor for Q * K.T.
  * \param[in]     dropout              Dropout probability.
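The four formulas above, together with the bias and mask enums, determine the math. As a sketch only, a single-head CPU reference might look like the following; it assumes row-major float buffers, omits dropout (inference behaviour), and is not the library's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major [rows][cols]

// P = Q*K^T + Bias, S = ScaleMaskSoftmax(P), O = S*V for one head.
// Bias handling shows NVTE_POST_SCALE_BIAS semantics: scale, then add bias.
// (NVTE_PRE_SCALE_BIAS would instead compute attn_scale * (dot + Bias[i][j]).)
Matrix reference_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                           const Matrix& Bias, float attn_scale, bool causal) {
  const std::size_t s_q = Q.size(), s_kv = K.size(), d = Q[0].size();
  Matrix O(s_q, std::vector<float>(d, 0.0f));
  for (std::size_t i = 0; i < s_q; ++i) {
    std::vector<float> p(s_kv, -INFINITY);
    for (std::size_t j = 0; j < s_kv; ++j) {
      if (causal && j > i) continue;  // NVTE_CAUSAL_MASK: drop future keys
      float dot = 0.0f;
      for (std::size_t c = 0; c < d; ++c) dot += Q[i][c] * K[j][c];
      p[j] = attn_scale * dot + Bias[i][j];  // post-scale bias
    }
    // Numerically stable softmax over the key axis; masked entries become 0.
    const float m = *std::max_element(p.begin(), p.end());
    float z = 0.0f;
    std::vector<float> s(s_kv, 0.0f);
    for (std::size_t j = 0; j < s_kv; ++j) { s[j] = std::exp(p[j] - m); z += s[j]; }
    // Accumulate O row i as the probability-weighted sum of V rows.
    for (std::size_t j = 0; j < s_kv; ++j)
      for (std::size_t c = 0; c < d; ++c) O[i][c] += (s[j] / z) * V[j][c];
  }
  return O;
}
```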
@@ -120,10 +146,11 @@ void nvte_fused_attn_fwd_qkvpacked(
 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *
  * Support Matrix:
- * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
- * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \verbatim
+   | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+   | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
+   | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \endverbatim
  *
  * \param[in]     QKV                  The QKV tensor in packed format,
  *                                     [total_seqs, 3, num_heads, head_dim].
@@ -135,8 +162,8 @@ void nvte_fused_attn_fwd_qkvpacked(
  * \param[out]    dQKV                 The gradient of the QKV tensor.
  * \param[out]    dBias                The gradient of the Bias tensor.
  * \param[in]     cu_seqlens           Accumulative sequence lengths, [batch_size + 1].
- * \param[in]     max_seqlen           Max sequence length used for computing,
- *                                     it may be >= max(cu_seqlens).
+ * \param[in]     max_seqlen           Max sequence length used for computing.
+ *                                     It may be >= max(cu_seqlens).
  * \param[in]     attn_scale           Scaling factor for Q * K.T.
  * \param[in]     dropout              Dropout probability.
  * \param[in]     qkv_layout           QKV tensor's layout.
@@ -165,15 +192,16 @@ void nvte_fused_attn_bwd_qkvpacked(
 /*! \brief Compute dot product attention with packed KV input.
  *
  * Computes:
- *  - P = Q * K.T + Bias
+ *  - P = Q * Transpose(K) + Bias
  *  - S = ScaleMaskSoftmax(P)
  *  - D = Dropout(S)
- *  - O = D * V.T
+ *  - O = D * Transpose(V)
  *
  * Support Matrix:
- * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \verbatim
+   | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+   | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \endverbatim
  *
  * \param[in]     Q                    The Q tensor, [total_seqs_q, num_heads, head_dim].
  * \param[in]     KV                   The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
@@ -184,10 +212,10 @@ void nvte_fused_attn_bwd_qkvpacked(
  * \param[in]     cu_seqlens_q         Accumulative sequence lengths for Q, [batch_size + 1].
  * \param[in]     cu_seqlens_kv        Accumulative sequence lengths for KV, [batch_size + 1].
  * \param[in]     rng_state            Seed and offset of CUDA random number generator.
- * \param[in]     max_seqlen_q         Max sequence length used for computing for Q.
- *                                     it may be >= max(cu_seqlens_q).
- * \param[in]     max_seqlen_kv        Max sequence length used for computing for KV.
- *                                     it may be >= max(cu_seqlens_kv).
+ * \param[in]     max_seqlen_q         Max sequence length used for computing
+ *                                     for Q. It may be >= max(cu_seqlens_q).
+ * \param[in]     max_seqlen_kv        Max sequence length used for computing
+ *                                     for KV. It may be >= max(cu_seqlens_kv).
  * \param[in]     is_training          Whether this is in training mode or inference.
  * \param[in]     attn_scale           Scaling factor for Q * K.T.
  * \param[in]     dropout              Dropout probability.
@@ -217,9 +245,10 @@ void nvte_fused_attn_fwd_kvpacked(
 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
  * Support Matrix:
- * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \verbatim
+   | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
+   | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+   \endverbatim
  *
  * \param[in]     Q                    The Q tensor, [total_seqs_q, num_heads, head_dim].
  * \param[in]     KV                   The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
@@ -233,10 +262,10 @@ void nvte_fused_attn_fwd_kvpacked(
  * \param[out]    dBias                The gradient of the Bias tensor.
  * \param[in]     cu_seqlens_q         Accumulative sequence lengths for Q, [batch_size + 1].
  * \param[in]     cu_seqlens_kv        Accumulative sequence lengths for KV, [batch_size + 1].
- * \param[in]     max_seqlen_q         Max sequence length used for computing for Q.
- *                                     it may be >= max(cu_seqlens_q).
- * \param[in]     max_seqlen_kv        Max sequence length used for computing for KV.
- *                                     it may be >= max(cu_seqlens_kv).
+ * \param[in]     max_seqlen_q         Max sequence length used for computing
+ *                                     for Q. It may be >= max(cu_seqlens_q).
+ * \param[in]     max_seqlen_kv        Max sequence length used for computing
+ *                                     for KV. It may be >= max(cu_seqlens_kv).
  * \param[in]     attn_scale           Scaling factor for Q * K.T.
  * \param[in]     dropout              Dropout probability.
  * \param[in]     qkv_layout           QKV tensor's layout.
...
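All of these signatures share the cu_seqlens convention documented above: cumulative sequence lengths with a leading zero, so sequence b occupies rows [cu_seqlens[b], cu_seqlens[b+1]) of the packed [total_seqs, ...] tensors. A small sketch; the helper name is hypothetical.

```cpp
#include <cstdint>
#include <vector>

// Builds a [batch_size + 1] cu_seqlens array from per-sequence lengths.
std::vector<int32_t> make_cu_seqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b)
    cu[b + 1] = cu[b] + seqlens[b];
  return cu;  // cu.back() == total_seqs
}

// Example: seqlens {3, 5, 2} gives cu_seqlens {0, 3, 8, 10}, so total_seqs
// is 10, and the max_seqlen passed to the kernels may be any value at or
// above the longest sequence (5 here).
```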
@@ -130,17 +130,23 @@ float *nvte_tensor_scale(const NVTETensor tensor);
  */
 float *nvte_tensor_scale_inv(const NVTETensor tensor);
 
+/*! \struct NVTETensorPack
+    \brief Pack of tensors, generally used for auxiliary outputs.
+ */
 struct NVTETensorPack {
-  static const int MAX_SIZE = 10;  /*!< we expect <10 matrices in auxiliary outputs */
-  NVTETensor tensors[MAX_SIZE];    /*!< wrappers to tensors, do not hold memory */
-  size_t size = 0;                 /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */
+  /*! Max number of tensors in the pack. Assumed <= 10. */
+  static const int MAX_SIZE = 10;
+  /*! Wrappers of tensors. They do not hold the associated memory. */
+  NVTETensor tensors[MAX_SIZE];
+  /*! Actual number of tensors in the pack, 0 <= size <= MAX_SIZE. */
+  size_t size = 0;
 };
 
-/*! \brief Create NVTETensors in NVTETensorPack.
+/*! \brief Create `tensors` in NVTETensorPack.
  */
 void nvte_tensor_pack_create(NVTETensorPack* pack);
 
-/*! \brief Destroy NVTETensors in NVTETensorPack.
+/*! \brief Destroy `tensors` in NVTETensorPack.
  */
 void nvte_tensor_pack_destroy(NVTETensorPack* pack);
...
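Based only on the declarations above, a typical lifecycle for the pack looks like this sketch; the fused-attention call that fills it is elided.

```cpp
#include <cstddef>
#include "transformer_engine/transformer_engine.h"

void example_aux_outputs() {
  NVTETensorPack aux;
  nvte_tensor_pack_create(&aux);  // initializes the NVTETensor wrappers

  // ... pass &aux as Aux_Output_Tensors to a fused-attention forward call;
  // the call sets aux.size and fills aux.tensors[0..aux.size-1] ...

  for (size_t i = 0; i < aux.size; ++i) {
    NVTETensor t = aux.tensors[i];  // wrapper only; owns no memory
    (void)t;  // inspect via the nvte_tensor_* accessors as needed
  }

  nvte_tensor_pack_destroy(&aux);  // destroys the wrappers, not the data
}
```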