Unverified Commit 3454f84d authored by Paweł Gadziński, committed by GitHub

[common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* add compile-time deprecation warning; \warning -> \deprecated
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent d20311bd
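
The change reroutes every deprecated *_qkvpacked / *_kvpacked entry point through the non-packed kernels: the wrapper computes per-tensor byte strides from the QKV layout group, builds Q/K/V tensor views over the packed buffer, and forwards the call. Below is a minimal sketch of that pointer arithmetic and of the compile-time deprecation hint mentioned in the commit message; the helper names (qkv_stride_3hd, unpack_qkv_3hd, QKVPtrs, legacy_qkvpacked_call) and the use of a [[deprecated]] attribute are illustrative assumptions, not the actual TransformerEngine declarations.

#include <cstddef>
#include <cstdint>

// Byte distance between the Q, K and V blocks of a packed BS3HD/T3HD (NVTE_3HD)
// buffer: bits-per-element * num_heads * head_dim / 8, matching the commit's
// calculate_qkv_stride() helper.
static std::size_t qkv_stride_3hd(std::size_t elem_bits, std::size_t h, std::size_t d) {
  return (elem_bits * h * d) / 8;
}

// Hypothetical view bundle; the real code aliases transformer_engine::Tensor objects.
struct QKVPtrs {
  std::int8_t *q;
  std::int8_t *k;
  std::int8_t *v;
};

// Split one packed QKV allocation into three aliased views, as the deprecated
// wrappers now do internally before calling the non-packed fused_attn_* kernels.
static QKVPtrs unpack_qkv_3hd(std::int8_t *qkv, std::size_t elem_bits,
                              std::size_t h, std::size_t d) {
  const std::size_t stride = qkv_stride_3hd(elem_bits, h, d);
  return {qkv, qkv + stride, qkv + 2 * stride};
}

// Assumed form of the "compile-time deprecation warning": a C++14 [[deprecated]]
// attribute on the packed wrapper surfaces the migration hint at build time.
[[deprecated("Use nvte_fused_attn_fwd with separate Q, K, V tensors instead.")]]
void legacy_qkvpacked_call();

On the caller side the packed tensors keep working through these shims, but new code should allocate or view Q, K and V separately and call the non-packed API directly.
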
@@ -15,6 +15,74 @@
#include "fused_attn_fp8.h"
#include "utils.h"
namespace {
// Helper function to create a tensor view with modified shape and optional pointer offset
transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source,
const std::vector<size_t> &shape,
size_t offset_bytes = 0) {
transformer_engine::Tensor view = *source;
if (offset_bytes > 0) {
view.data.dptr = static_cast<void *>(static_cast<int8_t *>(source->data.dptr) + offset_bytes);
}
view.data.shape = shape;
view.nvte_tensor = 0; // Mark as unmanaged/local tensor view
return view;
}
// Helper function to calculate stride for packed QKV tensor unpacking
size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
size_t h, size_t d) {
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = (transformer_engine::typeToNumBits(dtype) * d) / 8;
}
return stride;
}
// Helper function to determine unpacked shape for QKV packed tensor
std::vector<size_t> calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor,
size_t h, size_t d) {
std::vector<size_t> unpacked_shape;
if (qkv_tensor->data.shape.size() == 4) {
// T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1
unpacked_shape = {qkv_tensor->data.shape[0], h, d};
} else {
// BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2
unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d};
}
return unpacked_shape;
}
// Helper function to calculate stride for packed KV tensor unpacking
size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
size_t h_kv, size_t d) {
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = (transformer_engine::typeToNumBits(dtype) * d) / 8;
}
return stride;
}
// Helper function to determine unpacked shape for KV packed tensor
std::vector<size_t> calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor,
NVTE_QKV_Layout_Group layout_group,
NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv,
size_t d) {
std::vector<size_t> unpacked_kv_shape;
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
unpacked_kv_shape = {t_kv, h_kv, d};
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD ||
layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d};
}
return unpacked_kv_shape;
}
} // namespace
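// Worked example of the helpers above, under assumed dimensions (illustrative
// only, not taken from a real model config): BF16 elements (16 bits), h = 16, d = 64.
static_assert((16u * 16u * 64u) / 8u == 2048u,
              "NVTE_3HD: the K view starts 2048 bytes after Q, the V view 4096 bytes after Q");
static_assert((16u * 64u) / 8u == 128u,
              "NVTE_H3D: heads are interleaved, so the Q/K/V blocks sit 128 bytes apart per head");
// For a BS3HD input of shape {2, 128, 3, 16, 64}, calculate_qkv_unpacked_shape()
// returns {2, 128, 16, 64}: the packed "3" axis is dropped, and make_tensor_view()
// aliases the same allocation three times at byte offsets 0, stride and 2 * stride.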
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
@@ -436,6 +504,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
@@ -487,30 +557,62 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias,
output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state,
wkspace, stream, handle);
// Unpack QKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
// Unpack QKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
fused_attn_arbitrary_seqlen_fwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_QKV, input_output_S, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace,
stream, handle);
// Unpack QKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view,
input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -519,6 +621,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
}
// NVTE fused attention BWD with packed QKV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead.
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
@@ -570,9 +674,25 @@ void nvte_fused_attn_bwd_qkvpacked(
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle);
// Unpack QKV and dQKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V and dQ, dK, dV
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape);
Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride);
Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);
fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S,
&dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens,
input_cu_seqlens, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
@@ -588,12 +708,27 @@ void nvte_fused_attn_bwd_qkvpacked(
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
// Unpack QKV and dQKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V and dQ, dK, dV
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape);
Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride);
Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);
fused_attn_arbitrary_seqlen_bwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view,
&K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view,
&dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
@@ -605,10 +740,26 @@ void nvte_fused_attn_bwd_qkvpacked(
const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv,
input_S, input_output_dP, output_dQKV, input_cu_seqlens,
input_rng_state, wkspace, stream, handle);
// Unpack QKV and dQKV and call the non-packed function
const auto QKV_type = input_QKV->data.dtype;
size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d);
std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d);
// Create tensor views for Q, K, V and dQ, dK, dV
Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape);
Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride);
Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride);
Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape);
Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride);
Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);
fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO,
input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view,
input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -617,6 +768,8 @@ void nvte_fused_attn_bwd_qkvpacked(
}
}
// NVTE fused attention FWD with packed KV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
@@ -706,21 +859,40 @@ void nvte_fused_attn_fwd_kvpacked(
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked(
b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
// Unpack KV and call the non-packed function
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
// Unpack KV and call the non-packed function
const auto Q_type = input_Q->data.dtype;
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset,
output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
@@ -729,10 +901,20 @@ void nvte_fused_attn_fwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
// Unpack KV and call the non-packed function
const auto Q_type = input_Q->data.dtype;
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view,
input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -741,6 +923,8 @@ void nvte_fused_attn_fwd_kvpacked(
}
}
// NVTE fused attention BWD with packed KV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead.
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
@@ -806,10 +990,23 @@ void nvte_fused_attn_bwd_kvpacked(
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked(
b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
// Unpack KV and dKV and call the non-packed function
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape);
Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);
fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S,
output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q,
input_cu_seqlens_kv, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
@@ -825,13 +1022,29 @@ void nvte_fused_attn_bwd_kvpacked(
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
// Unpack KV and dKV and call the non-packed function
const auto Q_type = input_Q->data.dtype;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
// Create tensor views for dK, dV
Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape);
Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ,
output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S,
output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -843,11 +1056,25 @@ void nvte_fused_attn_bwd_kvpacked(
const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
// Unpack KV and dKV and call the non-packed function
const auto Q_type = input_Q->data.dtype;
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d);
std::vector<size_t> unpacked_kv_shape =
calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d);
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape);
Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride);
Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape);
Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view,
&dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace,
stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
......
@@ -1037,532 +1037,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS1 = nullptr;
void *devPtrS2 = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens = get_max_tokens(num_tokens);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_Max->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_Max->data.dtype = DType::kFloat32;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Sum_Exp->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_Sum_Exp->data.dtype = DType::kFloat32;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_Max->data.dptr;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS2 = output_Sum_Exp->data.dptr;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_S->data.dptr;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
}
size_t max_batch_size = 0;
size_t max_tokens = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens = get_max_tokens(num_tokens);
}
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}
void *devPtrO = output_O->data.dptr;
void *devPtrS1 = nullptr;
void *devPtrS2 = nullptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrPageTableK = page_table_k->data.dptr;
void *devPtrPageTableV = page_table_v->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_Max->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Max->data.dtype = DType::kFloat32;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Sum_Exp->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Sum_Exp->data.dtype = DType::kFloat32;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = nullptr;
output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
output_softmax_offset->data.dtype = DType::kFloat32;
}
Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_Max->data.dptr;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS2 = output_Sum_Exp->data.dptr;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_S->data.dptr;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = devPtrBias;
}
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
}
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void *devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
}
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride);
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrSoftmaxOffset = nullptr;
void *devPtrdSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
@@ -1604,8 +1078,8 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrPageTableK = page_table_k->data.dptr;
void *devPtrPageTableV = page_table_v->data.dptr;
void *devPtrPageTableK = page_table_k ? page_table_k->data.dptr : nullptr;
void *devPtrPageTableV = page_table_v ? page_table_v->data.dptr : nullptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
......
@@ -18,53 +18,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
......
@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devPtrCuSeqlen = cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
// Q shape is [b, s, h, d]
void *devPtrQ = input_Q->data.dptr;
// KV shape is [b, s, 2, h, d]
const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens,
Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
  // QKV shape is [b, s, 3, h, d]; Q, K and V are h * d elements apart. The stride
  // is in bytes and the max_512 kernels only handle 2-byte (fp16/bf16) elements,
  // hence the factor of 2.
  void *devPtrQKV = input_QKV->data.dptr;
  auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrdO = input_dO->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
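// Illustrative sketch (not part of the library API): the pointer arithmetic
// above splits a packed [b, s, 3, h, d] buffer of 2-byte (fp16/bf16) elements
// into Q, K and V, with K and V starting h * d and 2 * h * d elements past Q.
// The helper below is a hypothetical standalone restatement of it.
inline void split_qkv_bs3hd_16bit(void *qkv, size_t num_head, size_t head_dim, void **q, void **k,
                                  void **v) {
  const size_t bytes_per_elem = 2;  // the max_512 kernels only handle fp16/bf16
  const size_t stride = bytes_per_elem * num_head * head_dim;
  *q = qkv;
  *k = static_cast<void *>(static_cast<int8_t *>(qkv) + stride);
  *v = static_cast<void *>(static_cast<int8_t *>(qkv) + 2 * stride);
}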
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
  // Q shape is [b, s, h, d]
  // KV shape is [b, s, 2, h, d]; V starts h * d elements after K, and the byte
  // stride assumes 2-byte (fp16/bf16) elements.
  auto stride = 2 * num_head * head_dim;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrdO = input_dO->data.dptr;
// dQ shape is [b, s, h, d]
// dKV shape is [b, s, 2, h, d]
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dKV->data.dptr;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
const auto q_type = input_Q->data.dtype;
const auto kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
......@@ -18,25 +18,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
......@@ -2407,424 +2407,6 @@ void fused_attn_fp8_bwd_impl_v1(
} // namespace fused_attn
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O,
NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens,
const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const DType O_type = output_O->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = static_cast<void*>(devPtrQKV);
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + 2 * stride);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlens =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
void* devPtrDropoutSeed =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ,
devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO,
devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
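// Illustrative note (not part of the library): in the FP8 path the aux pack
// carries three tensors rather than one: M and ZInv, the per-row softmax
// statistics saved for the backward pass (both [batch, heads, seqlen, 1] in
// fp32), plus a copy of the 2-element rng_state. A first call with
// Aux_CTX_Tensors->size == 0 only publishes those shapes and dtypes; once the
// caller has allocated them, a second call with size == 3 runs the kernel and
// fills M and ZInv.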
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQKV->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = devPtrQKV;
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + 2 * stride);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
const DType O_type = input_O->data.dtype;
void* devPtrDescaleO = nullptr;
if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
devPtrDescaleO = input_O->scale_inv.dptr;
}
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
void* devPtrdQKV = output_dQKV->data.dptr;
void* devPtrdQ = devPtrdQKV;
void* devPtrdK = static_cast<void*>(static_cast<int8_t*>(devPtrdQKV) + stride);
void* devPtrdV = static_cast<void*>(static_cast<int8_t*>(devPtrdQKV) + 2 * stride);
void* devPtrAmaxdQ = output_dQKV->amax.dptr;
void* devPtrAmaxdK = output_dQKV->amax.dptr;
void* devPtrAmaxdV = output_dQKV->amax.dptr;
void* devPtrScaledQ = output_dQKV->scale.dptr;
void* devPtrScaledK = output_dQKV->scale.dptr;
void* devPtrScaledV = output_dQKV->scale.dptr;
void* devPtrcuSeqlens =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
void* devPtrDropoutSeed =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK,
devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK,
devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens,
devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
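// Illustrative sketch (not part of the library API): the backward pass only
// needs a dequantization (scale_inv) pointer for O when O is stored in FP8;
// for fp16/bf16 outputs it stays null, mirroring the branch above.
// Hypothetical restatement:
inline void *descale_ptr_if_fp8(const Tensor *t) {
  const DType dt = t->data.dtype;
  const bool is_fp8 = (dt == DType::kFloat8E4M3) || (dt == DType::kFloat8E5M2);
  return is_fp8 ? t->scale_inv.dptr : nullptr;
}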
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor* input_Q,
const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O,
NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q,
const Tensor* cu_seqlens_kv, const Tensor* rng_state,
Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType O_type = output_O->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlensQ =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed,
devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
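// Illustrative sketch (not part of the library API): rng_state is a 2-element
// int64 tensor holding the dropout RNG seed in element 0 and the starting
// offset in element 1, so the two device pointers handed to the kernel are
// just the base pointer and base + 1. Hypothetical restatement:
inline void split_rng_state(void *rng_state_dptr, void **seed, void **offset) {
  *seed = rng_state_dptr;
  *offset = static_cast<void *>(static_cast<uint64_t *>(rng_state_dptr) + 1);
}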
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q,
const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
const DType O_type = input_O->data.dtype;
void* devPtrDescaleO = nullptr;
if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
devPtrDescaleO = input_O->scale_inv.dptr;
}
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
void* devPtrdQ = output_dQ->data.dptr;
void* devPtrdKV = output_dKV->data.dptr;
void* devPtrdK = devPtrdKV;
void* devPtrdV = static_cast<void*>(static_cast<int8_t*>(devPtrdKV) + stride);
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dKV->amax.dptr;
void* devPtrAmaxdV = output_dKV->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dKV->scale.dptr;
void* devPtrScaledV = output_dKV->scale.dptr;
void* devPtrcuSeqlensQ =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset =
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ,
devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO,
devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK,
devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
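// Illustrative note (not part of the library): because dK and dV are written
// into one packed dKV tensor, they share that tensor's single amax/scale pair,
// while dQ carries its own quantization metadata from output_dQ:
//
//   devPtrAmaxdK  == devPtrAmaxdV  == output_dKV->amax.dptr
//   devPtrScaledK == devPtrScaledV == output_dKV->scale.dptr
//   devPtrAmaxdQ  == output_dQ->amax.dptr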
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
......
......@@ -13,47 +13,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
......
......@@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
/*! \brief Compute dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
*
* Computes:
* - P = Q * Transpose(K) + Bias
......@@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
......@@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
*
* Support Matrix:
\verbatim
......@@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
......@@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
*
* Computes:
* - P = Q * Transpose(K) + Bias
......@@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
......@@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked(
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
*
* Support Matrix:
\verbatim
......@@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
......