Unverified Commit 3454f84d authored by Paweł Gadziński, committed by GitHub

[common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* deprecated compile-time warning + \warning -> \deprecated
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent d20311bd
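As background for the deprecation change below: the commit attaches the C++14 [[deprecated("...")]] attribute to the packed entry points, so the compiler flags every remaining call site at build time. A minimal sketch of the mechanism, using hypothetical function names that are not part of this diff:

// The attribute turns each call of the old entry point into a compile-time
// diagnostic (-Wdeprecated-declarations on GCC/Clang) without breaking the build.
[[deprecated("old_packed_fn() is deprecated. Please use new_fn() instead.")]]
void old_packed_fn() {}
void new_fn() {}

int main() {
  old_packed_fn();  // warning: 'old_packed_fn' is deprecated: ...
  new_fn();         // compiles without diagnostics
}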
@@ -18,53 +18,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
...
@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devPtrCuSeqlen = cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
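A note on the packed-QKV pointer arithmetic in the function above: the offsets are taken in bytes (via int8_t arithmetic), so stride = 2 * num_head * head_dim is a byte stride that assumes the 2-byte FP16/BF16 elements this kernel operates on; within a [b, s, 3, h, d] tensor the Q, K, and V slices start h * d elements apart. A standalone sketch under that assumption (the helper name is illustrative):

#include <cstddef>
#include <cstdint>

// Split a packed [b, s, 3, h, d] QKV buffer of 2-byte elements into base
// pointers for Q, K, and V: byte stride = h * d elements * 2 bytes each.
void split_qkv_fp16(void *qkv, size_t num_head, size_t head_dim,
                    void **q, void **k, void **v) {
  const size_t stride_bytes = 2 * num_head * head_dim;
  *q = qkv;
  *k = static_cast<int8_t *>(qkv) + stride_bytes;
  *v = static_cast<int8_t *>(qkv) + 2 * stride_bytes;
}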
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
// Q shape is [b, s, h, d]
void *devPtrQ = input_Q->data.dptr;
// KV shape is [b, s, 2, h, d]
const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
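Both wrappers above implement the same two-pass workspace protocol: a first call with a null workspace pointer only reports the required byte count through workspace->data.shape (or a 1-byte placeholder when nothing is needed), and a second call with allocated memory does the real work. A minimal caller-side model of that contract, with illustrative names rather than the library API:

#include <cstddef>
#include <vector>

// Pass 1 (ws == nullptr): report the requirement and return immediately.
// Pass 2 (ws != nullptr): run using the provided buffer.
void kernel(void *ws, size_t *ws_size) {
  if (ws == nullptr) {
    *ws_size = 4096;  // placeholder requirement
    return;
  }
  // ... real work using *ws_size bytes at ws ...
}

int main() {
  size_t size = 0;
  kernel(nullptr, &size);                           // query pass
  std::vector<unsigned char> buf(size ? size : 1);  // 1-byte placeholder if 0
  kernel(buf.data(), &size);                        // execute pass
}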
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens,
Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrdO = input_dO->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
auto stride = 2 * num_head * head_dim;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrdO = input_dO->data.dptr;
// dQ shape is [b, s, h, d]
// dKV shape is [b, s, 2, h, d]
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dKV->data.dptr;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
const auto q_type = input_Q->data.dtype;
const auto kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...
@@ -18,25 +18,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...
@@ -13,47 +13,6 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
...
@@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
/*! \brief Compute dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
*
* Computes:
* - P = Q * Transpose(K) + Bias
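(The doc comment is truncated by the diff view here. For context, the abbreviated formula is standard scaled dot-product attention; a sketch in LaTeX, not quoted from the header:)

P = \mathrm{scale}\cdot Q K^{\top} + \mathrm{Bias},\qquad
S = \mathrm{Softmax}\big(\mathrm{Mask}(P)\big),\qquad
O = \mathrm{Dropout}(S)\,V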
@@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
@@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTETensor workspace, cudaStream_t stream);
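With the attribute above, every remaining call to nvte_fused_attn_fwd_qkvpacked() now warns at compile time. Callers that need a transition period can silence the diagnostic locally while migrating to nvte_fused_attn_fwd(); a GCC/Clang sketch (a code fragment, with the existing argument list elided):

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
nvte_fused_attn_fwd_qkvpacked(/* ... existing arguments ... */);
#pragma GCC diagnostic pop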
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
*
* Support Matrix:
\verbatim
@@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
@@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
*
* Computes:
* - P = Q * Transpose(K) + Bias
@@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
@@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked(
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
*
* Support Matrix:
\verbatim
@@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
...