Unverified commit 3454f84d, authored by Paweł Gadziński, committed by GitHub
Browse files

[common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* deprecated compile-time warning + \warning -> \deprecated
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent d20311bd
...@@ -18,53 +18,6 @@ ...@@ -18,53 +18,6 @@
namespace transformer_engine { namespace transformer_engine {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// Prototype of the forward entry point that takes one packed QKV tensor (input_QKV) instead of
// separate Q, K and V tensors. A single max_seqlen / cu_seqlens / cu_seqlens_padded set is
// passed, i.e. there is no separate query vs. key/value sequence information in this signature.
// Only the declaration is visible here; behaviour of the definition is not documented on purpose.
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the backward counterpart for a packed QKV tensor. Gradients are returned through
// the non-const output_dQKV, output_dBias and output_dSoftmaxOffset tensors; output_S is also
// non-const. Compared with the forward prototype it additionally takes a `deterministic` flag.
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the forward entry point that takes a separate Q tensor and one packed KV tensor
// (input_KV). Query and key/value sides have their own max_seqlen, num_tokens and cu_seqlens
// arguments, and the signature carries page-table related arguments (num_pages_*, page_size_*,
// max_pages_per_seq_*, page_table_k/v).
// NOTE(review): the page-table arguments presumably describe a paged KV cache - confirm against
// the definition, which is not part of this view.
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the backward counterpart for separate Q and packed KV. Gradients are returned
// through output_dQ, output_dKV, output_dBias and output_dSoftmaxOffset. Unlike the forward
// prototype above, this signature has no page-table arguments.
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
......
...@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv ...@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
} // namespace fused_attn } // namespace fused_attn
using namespace transformer_engine::fused_attn; using namespace transformer_engine::fused_attn;
/*! \brief Forward pass of the max-512 fused attention path for a packed QKV tensor.
 *
 *  Derives the Q, K and V device pointers from the single packed buffer and forwards them to
 *  fused_attn_max_512_fwd_impl. max_seqlen and cu_seqlens are passed for both the query and the
 *  key/value side.
 *
 *  Aux_CTX_Tensors handling:
 *   - size == 0: size is set to 1 and tensors[0] (S) gets a null data pointer, shape
 *     {batch, num_head, max_seqlen, max_seqlen} and the dtype of input_QKV. A null S pointer
 *     is passed to the impl in this case.
 *   - size == 1: the data pointer of tensors[0] is passed to the impl as S.
 *   - any other size: NVTE_ERROR.
 *
 *  Workspace handling after the impl call:
 *   - reported size == 0: workspace shape is set to {1}, dtype kByte.
 *   - reported size  > 0 and workspace->data.dptr is null: workspace shape is set to
 *     {workspace_size}, dtype kByte.
 *   - otherwise the workspace descriptor is left untouched.
 *
 *  \param[in]     input_QKV        Packed QKV tensor.
 *  \param[in]     input_Bias       Bias tensor; its data pointer is forwarded as-is.
 *  \param[out]    output_O         Output tensor; its data pointer is forwarded as-is.
 *  \param[in,out] Aux_CTX_Tensors  Auxiliary tensor pack, see above.
 *  \param[in]     cu_seqlens       Used for both Q and KV sequence lengths.
 *  \param[in]     rng_state        Must have dtype kInt64. Element 0 is forwarded as the dropout
 *                                  seed pointer, element 1 as the dropout offset pointer.
 *  \param[in,out] workspace        Workspace tensor, see above.
 *
 *  All remaining scalar/enum arguments, stream and handle are forwarded unchanged to the impl.
 */
void fused_attn_max_512_fwd_qkvpacked(
    size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;

  // QKV shape is [b, s, 3, h, d]
  void *devPtrQKV = input_QKV->data.dptr;
  // The offset is applied through int8_t pointer arithmetic, i.e. it is a byte offset.
  // NOTE(review): the factor 2 presumably is the byte width of a 16-bit element type - confirm
  // that no other element size can reach this function.
  const auto stride = 2 * num_head * head_dim;

  void *devPtrQ = devPtrQKV;
  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);

  void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
  void *devPtrO = output_O->data.dptr;

  // S stays null on the first call, where only its shape/dtype are published.
  void *devPtrS = nullptr;
  if (Aux_CTX_Tensors->size == 0) {
    Aux_CTX_Tensors->size = 1;
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
    output_S->data.dptr = nullptr;
    output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
    output_S->data.dtype = input_QKV->data.dtype;
  } else if (Aux_CTX_Tensors->size == 1) {
    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
    devPtrS = output_S->data.dptr;
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }

  void *devPtrCuSeqlen = cu_seqlens->data.dptr;

  const DType rng_state_type = rng_state->data.dtype;
  NVTE_CHECK(rng_state_type == DType::kInt64);
  // rng_state holds two 64-bit values: [0] is used as seed, [1] as offset.
  void *devPtrDropoutSeed = rng_state->data.dptr;
  void *devPtrDropoutOffset =
      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);

  const DType QKV_type = input_QKV->data.dtype;
  size_t workspace_size = 0;

  fused_attn_max_512_fwd_impl(
      batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
      devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
      &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);

  // workspace_size is unsigned, so only the "zero" and "non-zero" cases exist; the former
  // "negative size" error branch could never execute and has been removed.
  if (workspace_size == 0) {
    workspace->data.shape = {1};
    workspace->data.dtype = DType::kByte;
    return;
  }
  if (workspace->data.dptr == nullptr) {
    // No buffer supplied yet: publish the required size through the workspace descriptor.
    workspace->data.shape = {workspace_size};
    workspace->data.dtype = DType::kByte;
  }
}
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
// Q shape is [b, s, h, d]
void *devPtrQ = input_Q->data.dptr;
// KV shape is [b, s, 2, h, d]
const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
...@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, ...@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
} }
} }
/*! \brief Backward pass of the max-512 fused attention path for a packed QKV tensor.
 *
 *  Derives Q/K/V pointers from input_QKV and dQ/dK/dV pointers from output_dQKV using the same
 *  byte offsets, then forwards everything to fused_attn_max_512_bwd_impl. max_seqlen and
 *  cu_seqlens are passed for both the query and the key/value side. The dS pointer handed to
 *  the impl is the same pointer as S (output_S->data.dptr).
 *
 *  Workspace handling after the impl call:
 *   - reported size == 0: workspace shape is set to {1}, dtype kByte.
 *   - reported size  > 0 and workspace->data.dptr is null: workspace shape is set to
 *     {workspace_size}, dtype kByte.
 *   - otherwise the workspace descriptor is left untouched.
 *
 *  \param[in]     input_QKV     Packed QKV tensor.
 *  \param[in]     input_dO      Gradient of the attention output.
 *  \param[in,out] output_S      S tensor; its buffer is also handed to the impl as dS.
 *  \param[out]    output_dQKV   Packed gradient tensor, addressed with the same offsets as QKV.
 *  \param[out]    output_dBias  Bias gradient tensor; its data pointer is forwarded as-is.
 *  \param[in]     cu_seqlens    Used for both Q and KV sequence lengths.
 *  \param[in,out] workspace     Workspace tensor, see above.
 *
 *  All remaining scalar/enum arguments, stream and handle are forwarded unchanged to the impl.
 */
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
                                      size_t head_dim, float attn_scale, float p_dropout,
                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
                                      Tensor *output_dBias, const Tensor *cu_seqlens,
                                      Tensor *workspace, cudaStream_t stream,
                                      cudnnHandle_t handle) {
  using namespace transformer_engine;

  // QKV shape is [b, s, 3, h, d]
  void *devPtrQKV = input_QKV->data.dptr;
  // The offset is applied through int8_t pointer arithmetic, i.e. it is a byte offset.
  // NOTE(review): the factor 2 presumably is the byte width of a 16-bit element type - confirm
  // that no other element size can reach this function.
  const auto stride = 2 * num_head * head_dim;

  void *devPtrQ = devPtrQKV;
  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);

  void *devPtrdO = input_dO->data.dptr;

  // dQKV shape is [b, s, 3, h, d]
  void *devPtrdQKV = output_dQKV->data.dptr;
  void *devPtrdQ = devPtrdQKV;
  void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);

  void *devPtrdBias = output_dBias->data.dptr;

  void *devPtrS = output_S->data.dptr;
  // devPtrdS reuses the memory of devPtrS
  void *devPtrdS = devPtrS;

  void *devPtrCuSeqlens = cu_seqlens->data.dptr;

  const auto qkv_type = input_QKV->data.dtype;
  size_t workspace_size = 0;

  // Note the argument order: mask_type precedes bias_type in this call.
  fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
                              p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
                              devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
                              devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
                              &workspace_size, get_cudnn_dtype(qkv_type), stream, handle);

  // workspace_size is unsigned, so only the "zero" and "non-zero" cases exist; the former
  // "negative size" error branch could never execute and has been removed.
  if (workspace_size == 0) {
    workspace->data.shape = {1};
    workspace->data.dtype = DType::kByte;
    return;
  }
  if (workspace->data.dptr == nullptr) {
    // No buffer supplied yet: publish the required size through the workspace descriptor.
    workspace->data.shape = {workspace_size};
    workspace->data.dtype = DType::kByte;
  }
}
/*! \brief Backward pass of the max-512 fused attention path for separate Q and packed KV.
 *
 *  Takes Q/dQ from their own tensors, derives K/V from input_KV and dK/dV from output_dKV using
 *  the same byte offset, then forwards everything to fused_attn_max_512_bwd_impl. The dS pointer
 *  handed to the impl is the same pointer as S (output_S->data.dptr).
 *
 *  Precondition enforced here: input_Q and input_KV have the same dtype.
 *
 *  Workspace handling after the impl call:
 *   - reported size == 0: workspace shape is set to {1}, dtype kByte.
 *   - reported size  > 0 and workspace->data.dptr is null: workspace shape is set to
 *     {workspace_size}, dtype kByte.
 *   - otherwise the workspace descriptor is left untouched.
 *
 *  \param[in]     input_Q        Query tensor.
 *  \param[in]     input_KV       Packed key/value tensor.
 *  \param[in]     input_dO       Gradient of the attention output.
 *  \param[in,out] output_S       S tensor; its buffer is also handed to the impl as dS.
 *  \param[out]    output_dQ      Query gradient tensor.
 *  \param[out]    output_dKV     Packed key/value gradient tensor, addressed like input_KV.
 *  \param[out]    output_dBias   Bias gradient tensor; its data pointer is forwarded as-is.
 *  \param[in]     q_cu_seqlens   Forwarded as the Q sequence-length pointer.
 *  \param[in]     kv_cu_seqlens  Forwarded as the KV sequence-length pointer.
 *  \param[in,out] workspace      Workspace tensor, see above.
 *
 *  All remaining scalar/enum arguments, stream and handle are forwarded unchanged to the impl.
 */
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                     const Tensor *input_Q, const Tensor *input_KV,
                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
                                     Tensor *output_dKV, Tensor *output_dBias,
                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;

  // Q shape is [b, s, h, d]
  // KV shape is [b, s, 2, h, d]
  // The offset is applied through int8_t pointer arithmetic, i.e. it is a byte offset.
  // NOTE(review): the factor 2 presumably is the byte width of a 16-bit element type - confirm
  // that no other element size can reach this function.
  const auto stride = 2 * num_head * head_dim;
  void *devPtrQ = input_Q->data.dptr;
  void *devPtrK = input_KV->data.dptr;
  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);

  void *devPtrdO = input_dO->data.dptr;

  // dQ shape is [b, s, h, d]
  // dKV shape is [b, s, 2, h, d]
  void *devPtrdQ = output_dQ->data.dptr;
  void *devPtrdK = output_dKV->data.dptr;
  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);

  void *devPtrdBias = output_dBias->data.dptr;

  void *devPtrS = output_S->data.dptr;
  // devPtrdS reuses the memory of devPtrS
  void *devPtrdS = devPtrS;

  void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
  void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;

  const auto q_type = input_Q->data.dtype;
  const auto kv_type = input_KV->data.dtype;
  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");

  size_t workspace_size = 0;

  // Note the argument order: mask_type precedes bias_type in this call.
  fused_attn_max_512_bwd_impl(
      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
      mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
      devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
      &workspace_size, get_cudnn_dtype(q_type), stream, handle);

  // workspace_size is unsigned, so only the "zero" and "non-zero" cases exist; the former
  // "negative size" error branch could never execute and has been removed.
  if (workspace_size == 0) {
    workspace->data.shape = {1};
    workspace->data.dtype = DType::kByte;
    return;
  }
  if (workspace->data.dptr == nullptr) {
    // No buffer supplied yet: publish the required size through the workspace descriptor.
    workspace->data.shape = {workspace_size};
    workspace->data.dtype = DType::kByte;
  }
}
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
...@@ -18,25 +18,6 @@ ...@@ -18,25 +18,6 @@
namespace transformer_engine { namespace transformer_engine {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
// Prototype of the max-512 forward entry point for a single packed QKV tensor (input_QKV).
// One max_seqlen / cu_seqlens pair is passed; Aux_CTX_Tensors and workspace are non-const.
// The fourth parameter is named head_dim to match the definition and the sibling prototypes
// (it was previously spelled head_size here; parameter names in a declaration do not affect
// callers).
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
                                      size_t head_dim, bool is_training, float attn_scale,
                                      float p_dropout, NVTE_QKV_Layout qkv_layout,
                                      NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                      const Tensor *input_QKV, const Tensor *input_Bias,
                                      Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
                                      const Tensor *cu_seqlens, const Tensor *rng_state,
                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the max-512 forward entry point for a separate Q tensor (input_Q) and one packed
// KV tensor (input_KV). Query and key/value sides have their own max sequence length
// (q_max_seqlen / kv_max_seqlen) and cumulative sequence length tensors
// (q_cu_seqlens / kv_cu_seqlens). Aux_CTX_Tensors and workspace are non-const.
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, bool is_training, size_t kv_max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
...@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, ...@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle); cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the max-512 backward entry point for a single packed QKV tensor. Gradients are
// returned through the non-const output_dQKV and output_dBias tensors; output_S is non-const as
// well. Unlike the forward prototypes there is no rng_state, is_training or Aux_CTX_Tensors
// argument in this signature.
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// Prototype of the max-512 backward entry point for a separate Q tensor and one packed KV
// tensor. Gradients are returned through output_dQ, output_dKV and output_dBias; query and
// key/value sides have their own max sequence length and cumulative sequence length arguments.
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t head_dim, float attn_scale, size_t kv_max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
...@@ -13,47 +13,6 @@ ...@@ -13,47 +13,6 @@
namespace transformer_engine { namespace transformer_engine {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
// Prototype only. Takes one packed tensor (input_QKV) and a single max_seqlen / cu_seqlens;
// input_output_S, output_O, Aux_CTX_Tensors and workspace are non-const.
void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
// Prototype only. Besides QKV, O and dO it takes input_M, input_ZInv and input_S;
// input_output_dP is non-const.
// NOTE(review): output_dQKV is declared `const Tensor *` although its name marks it as an
// output - confirm how the definition writes the gradient.
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
// fused attention FWD FP8 with packed KV
// Prototype only. Takes a separate Q tensor and one packed KV tensor, with separate
// max_seqlen_q / max_seqlen_kv and cu_seqlens_q / cu_seqlens_kv, plus num_gqa_groups.
void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
// fused attention BWD FP8 with packed KV
// Prototype only. Besides Q, KV, O and dO it takes input_M, input_ZInv and input_S;
// input_output_dP is non-const.
// NOTE(review): output_dQ and output_dKV are declared `const Tensor *` although their names
// mark them as outputs - confirm how the definition writes the gradients.
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V // fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
......
...@@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
int64_t window_size_right, bool return_max_logit, bool cuda_graph); int64_t window_size_right, bool return_max_logit, bool cuda_graph);
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
* *
* Computes: * Computes:
* - P = Q * Transpose(K) + Bias * - P = Q * Transpose(K) + Bias
...@@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
[[deprecated(
"nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
...@@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input. /*! \brief Compute the backward of the dot product attention with packed QKV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
...@@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
[[deprecated(
"nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_qkvpacked( void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
...@@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input. /*! \brief Compute dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead.
* *
* Computes: * Computes:
* - P = Q * Transpose(K) + Bias * - P = Q * Transpose(K) + Bias
...@@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
[[deprecated(
"nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
...@@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked(
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input. /*! \brief Compute the backward of the dot product attention with packed KV input.
*
* \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead.
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
...@@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
[[deprecated(
"nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_bwd_kvpacked( void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment