Unverified commit 7fb22c37 authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add max_t support for THD (#1244)

* WIP: add max_t support for THD
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: save tensors for debug and point to new FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add docstring for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: first try on adding max_b and max_t
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit c3d522e9f5aef3c8ddfec5bf6ff24c3db97bb059.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "WIP: first try on adding max_b and max_t"

This reverts commit 3bc01ebaf2aa846fd16634e2d33b0d0f5803a076.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update docstring and fix max_seqlen logic for thd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert two lines of change in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add get_max_b/t
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix max_seqlen code and docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* success: add max_b/max_t
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change max_b/max_t buckets
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix b vs orig_b
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix b vs orig_b with 0 fill
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE for T3HD/TH3D
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add max_b to conversion kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix changes after last merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add Jax support for max_t
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.8.0-rc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 1.8.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* code review/formatting fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix Stats shape for cuDNN < 9.6
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return nullptr for offset_stats when cudnn < 9.6
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more version control
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 83f9cc09
Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45
......@@ -619,7 +619,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(1, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
......
......@@ -272,6 +272,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
size_t t = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t = input_QKV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
......@@ -292,7 +297,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
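
The `t` threaded into the call above is the total token count of the packed THD input, and stays zero for non-ragged formats. A minimal sketch of that shape logic, assuming the same convention as the diff (the leading dimension of a THD tensor is the token axis):

```cpp
#include <cstddef>
#include <vector>

// Sketch only: for ragged layouts such as t3hd/th3d the packed QKV tensor is
// shaped {t, 3, h, d} or {t, h, 3, d}, so the total token count is the
// leading dimension; BSHD/SBHD formats have no single token axis.
size_t thd_total_tokens(const std::vector<size_t> &qkv_shape, bool is_thd) {
  return is_thd ? qkv_shape[0] : 0;
}
```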
......@@ -349,6 +354,11 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
size_t t = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t = input_QKV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
......@@ -377,7 +387,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
......@@ -442,6 +452,13 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
t_kv = input_KV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -463,9 +480,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
......@@ -526,6 +543,13 @@ void nvte_fused_attn_bwd_kvpacked(
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
t_kv = input_KV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -556,9 +580,9 @@ void nvte_fused_attn_bwd_kvpacked(
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV,
input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
......@@ -616,6 +640,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
t_kv = input_K->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -637,9 +668,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
#else
......@@ -696,6 +727,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
t_kv = input_K->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -726,10 +764,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
......
......@@ -49,14 +49,14 @@ namespace transformer_engine {
namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -73,10 +73,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
const auto cudnn_runtime_version = cudnnGetVersion();
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = max_t_q;
s_kv = max_t_kv;
}
const auto cudnn_runtime_version = cudnnGetVersion();
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
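
Here `max_b`, `max_t_q`, and `max_t_kv` come from `get_max_batch_size`/`get_max_tokens`, which are defined outside this hunk; the exact bucket boundaries are not shown in the diff. A plausible power-of-two scheme, offered only as an illustration of why the substitution keeps the graph shape stable:

```cpp
#include <cstdint>

// Illustrative bucketing (assumption: the real get_max_batch_size /
// get_max_tokens may use different boundaries). Rounding the batch and token
// counts up to a fixed bucket makes every step in the same bucket present an
// identical {b, s_q, s_kv} to cuDNN, so one cached graph serves them all.
int64_t round_up_pow2(int64_t n) {
  int64_t bucket = 1;
  while (bucket < n) bucket <<= 1;
  return bucket;
}
```

Under such a scheme, token counts of 2049 through 4096 would all map to the 4096 bucket and reuse the same built plan instead of recompiling per step.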
......@@ -117,6 +125,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes>>; // dropout_offset
......@@ -140,9 +149,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
if (is_ragged) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
......@@ -158,23 +179,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
if (is_ragged) {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
......@@ -268,6 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
O->set_output(true)
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
......@@ -276,10 +285,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride);
}
if (is_ragged && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, 1, h, 1})
.set_ragged_offset(offset_stats);
} else {
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
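
With the ragged stride `{h * s_q, 1, h, 1}` the softmax stats buffer is token-major, effectively `[t, h, 1]`, so each sequence's slice begins at its padded token offset times `h`. A host-side sketch of those ragged offsets (an assumption about what the device kernel writes into `devOffsetsS`, since the kernel body is not in this diff):

```cpp
#include <cstdint>
#include <vector>

// Sketch: ragged offsets (in elements) for the token-major softmax stats.
// With one float per token per head, sequence i starts
// cu_seqlens_q_padded[i] * h elements into the stats buffer.
std::vector<int64_t> stats_ragged_offsets(
    const std::vector<int32_t> &cu_seqlens_q_padded, int64_t h) {
  std::vector<int64_t> out(cu_seqlens_q_padded.size());
  for (size_t i = 0; i < out.size(); ++i)
    out[i] = static_cast<int64_t>(cu_seqlens_q_padded[i]) * h;
  return out;
}
```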
......@@ -291,8 +314,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
: std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);
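
These tuples are what the per-descriptor graph cache stores; splitting the old `offset_tuple` into `offset_qkvo_tuple` and `offset_s_tuple` lets pre-9.6 entries carry a null stats offset without changing the cache layout elsewhere. As a toy illustration of the caching pattern (the names below are stand-ins, not the real cache types):

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <tuple>

struct BuiltGraph {};  // stand-in for the compiled cudnn_frontend graph
using BucketKey = std::tuple<int64_t, int64_t, int64_t>;  // max_b, max_t_q, max_t_kv

// Toy cache: because the key holds bucketed sizes rather than the raw batch
// shape, every step that lands in the same bucket reuses the built plan.
std::shared_ptr<BuiltGraph> get_or_build(
    std::map<BucketKey, std::shared_ptr<BuiltGraph>> &cache, BucketKey key) {
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // bucket hit
  auto g = std::make_shared<BuiltGraph>();   // bucket miss: build once
  cache.emplace(key, g);
  return g;
}
```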
......@@ -302,15 +328,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple,
bias_tuple, padding_tuple, offset_tuple, dropout_tuple);
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k,
offset_v, offset_o, dropout_seed, dropout_offset] =
offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
......@@ -318,10 +345,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
// We do this by adding padding at the end of each separate allocation.
auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size());
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen;
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged) {
if (cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
}
}
if (workspace == nullptr) {
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
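
The workspace is a single allocation carved into 16-byte-aligned regions; under the cuDNN >= 9.6 ragged path a fifth offset region holds the stats offsets. A sketch of the sizing arithmetic, assuming the region order implied by the pointer math later in this function:

```cpp
#include <cstddef>

// Sketch of the total workspace:
// [ plan | seqlen_q | seqlen_kv | off_q | off_k | off_v | off_o | off_stats ]
// The seqlen regions exist only with padding masks, the offset regions only
// for ragged (THD) input, and off_stats only for cuDNN >= 9.6.
size_t total_workspace_bytes(size_t plan_bytes, size_t per_seqlen_bytes,
                             size_t per_offset_bytes, bool is_padding,
                             bool is_ragged, bool cudnn_ge_9_6) {
  const size_t seqlen_bytes = is_padding ? 2 * per_seqlen_bytes : 0;
  const size_t offset_bytes =
      is_ragged ? (cudnn_ge_9_6 ? 5 : 4) * per_offset_bytes : 0;
  return plan_bytes + seqlen_bytes + offset_bytes;
}
```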
......@@ -348,7 +382,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + num_bytes_per_seqlen;
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
variant_pack[seq_q] = devActualSeqlenQ;
......@@ -363,15 +397,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
void *devOffsetsS = nullptr;
if (cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
}
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO);
devOffsetsV, devOffsetsO, devOffsetsS);
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
if (cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS;
}
}
if (is_dropout) {
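
Both helper kernels now receive `actual_b` (the real batch size) alongside the bucketed `b`: real entries are produced for indices up to `actual_b`, and the rest of the bucket is filled so the enlarged graph never reads uninitialized offsets. A hedged sketch of that indexing, not the actual kernel body (zero fill is an assumption taken from the "fix b vs orig_b with 0 fill" commit; the real `cu_seqlens_padded_to_offsets` also derives Q/K/V/O offsets per layout group):

```cuda
#include <cstdint>

// Sketch only: write stats ragged offsets for a bucket of max_b sequences,
// of which only the first actual_b are real.
__global__ void offsets_sketch(int actual_b, int max_b, int h,
                               const int32_t *cu_seqlens_q_padded,
                               int64_t *offsets_stats) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i > max_b) return;  // the offsets array has max_b + 1 entries
  const int64_t tokens = (i <= actual_b) ? cu_seqlens_q_padded[i] : 0;
  offsets_stats[i] = tokens * static_cast<int64_t>(h);
}
```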
......@@ -386,12 +427,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -414,6 +456,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = max_t_q;
s_kv = max_t_kv;
}
// We choose between 32-bit and 64-bit offsets depending on need.
// This allows us to support older cuDNN runtimes gracefully.
......@@ -462,6 +514,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes>>; // dropout_offset
......@@ -485,9 +538,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
......@@ -508,20 +576,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
......@@ -569,11 +623,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
}
if (is_ragged && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, 1, h, 1})
.set_data_type(fe::DataType_t::FLOAT)
.set_ragged_offset(offset_stats));
} else {
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
......@@ -589,6 +658,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (is_ragged && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left + 1);
}
......@@ -682,8 +755,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
: std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);
......@@ -693,15 +769,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
padding_tuple, offset_tuple, dropout_tuple);
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
offset_qkvo_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] =
offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
......@@ -709,10 +786,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
// We do this by adding padding at the end of each separate allocation.
auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size());
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen;
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged) {
if (cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
}
}
if (workspace == nullptr) {
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
......@@ -752,7 +836,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + num_bytes_per_seqlen;
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
variant_pack[seq_q] = devActualSeqlenQ;
......@@ -767,15 +851,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
void *devOffsetsS = nullptr;
if (cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
}
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO);
devOffsetsV, devOffsetsO, devOffsetsS);
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
if (cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS;
}
}
if (is_dropout) {
......@@ -792,10 +883,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -803,6 +894,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
......@@ -821,17 +913,30 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens = get_max_tokens(num_tokens);
}
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
......@@ -845,7 +950,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
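
A worked sizing comparison for the two `output_S` shapes above, using hypothetical numbers (b = 3, h = 16, max_seqlen = 2048, and a token count bucketed up to max_tokens = 4096):

```cpp
#include <cstddef>

// Hypothetical numbers, for scale only. The THD stats tensor is sized by the
// bucketed token count, independent of how tokens split across sequences:
constexpr size_t kThdStatsElems = 4096 * 16 * 1;         // {max_tokens, h, 1}
// The padded layout pays for every (batch, position) pair up to max_seqlen:
constexpr size_t kPaddedStatsElems = 3 * 16 * 2048 * 1;  // {b, h, max_seqlen, 1}
static_assert(kThdStatsElems == 65536 && kPaddedStatsElems == 98304,
              "65,536 vs 98,304 float32 elements");
```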
......@@ -875,12 +984,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b,
bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -898,10 +1007,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
}
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
......@@ -909,7 +1018,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
......@@ -934,6 +1042,14 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
bias_h = output_dBias->data.shape[1];
}
size_t max_batch_size = 0;
size_t max_tokens = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens = get_max_tokens(num_tokens);
}
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
......@@ -952,12 +1068,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b,
bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK,
devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -975,19 +1092,21 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
}
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
......@@ -1005,6 +1124,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -1013,12 +1133,26 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens_q = get_max_tokens(num_tokens_q);
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
......@@ -1032,7 +1166,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
......@@ -1063,11 +1201,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ,
devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1086,12 +1224,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1122,6 +1261,16 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
bias_h = output_dBias->data.shape[1];
}
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens_q = get_max_tokens(num_tokens_q);
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
......@@ -1143,12 +1292,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1167,8 +1316,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
......@@ -1177,6 +1327,7 @@ void fused_attn_arbitrary_seqlen_fwd(
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
......@@ -1196,12 +1347,26 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens_q = get_max_tokens(num_tokens_q);
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
......@@ -1215,7 +1380,11 @@ void fused_attn_arbitrary_seqlen_fwd(
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
......@@ -1246,11 +1415,11 @@ void fused_attn_arbitrary_seqlen_fwd(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ,
devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1269,13 +1438,13 @@ void fused_attn_arbitrary_seqlen_fwd(
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1296,6 +1465,16 @@ void fused_attn_arbitrary_seqlen_bwd(
bias_h = output_dBias->data.shape[1];
}
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
max_tokens_q = get_max_tokens(num_tokens_q);
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
......@@ -1315,12 +1494,12 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
@@ -19,47 +19,50 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
@@ -68,13 +71,13 @@ void fused_attn_arbitrary_seqlen_fwd(
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......
@@ -746,7 +746,7 @@ void fused_attn_max_512_fwd_impl(
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
b, b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
@@ -1169,7 +1169,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
b, b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
......
@@ -5,6 +5,7 @@
************************************************************************/
#include <algorithm>
#include <cmath>
#include "../common.h"
#include "transformer_engine/fused_attn.h"
@@ -353,66 +354,75 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *
}
// convert cu_seqlens to actual_seqlens
__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens,
__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b,
int32_t const *const q_cu_seqlens,
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b) {
if (tid < actual_b) {
q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid];
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
} else if (tid < max_b) {
q_seqlens[tid] = 0;
kv_seqlens[tid] = 0;
}
}
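// Launch sketch, for illustration only (buffer names are hypothetical, not from
// this commit): with actual_b = 3, max_b = 32, and cu_seqlens = [0, 3, 5, 9],
// threads 0..2 write seqlens {3, 2, 4} and threads 3..31 zero-fill, so the tail
// of the max_b-sized seqlen buffers stays well defined for static-shape graphs.
//   constexpr int kThreadsPerBlock = 128;
//   int blocks = (max_b + kThreadsPerBlock - 1) / kThreadsPerBlock;
//   cu_seqlens_to_actual_seqlens<<<blocks, kThreadsPerBlock, 0, stream>>>(
//       actual_b, max_b, d_cu_seqlens_q, d_cu_seqlens_kv, d_seqlens_q, d_seqlens_kv);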
// convert cu_seqlens_padded to offsets
template <class OFFSETS_T>
__device__ void cu_seqlens_padded_to_offsets_impl(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
OFFSETS_T *offsets_q, OFFSETS_T *offsets_k,
OFFSETS_T *offsets_v, OFFSETS_T *offsets_o) {
__device__ void cu_seqlens_padded_to_offsets_impl(
NVTE_QKV_Layout_Group layout_group, int64_t actual_b, int64_t max_b, int64_t h, int64_t hg,
int64_t d_qk, int64_t d_v, const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded, OFFSETS_T *offsets_q, OFFSETS_T *offsets_k,
OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid];
auto cu_seqlens_id = min(tid, actual_b);
if (tid <= max_b) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
if (offsets_s != nullptr) {
offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];
}
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid];
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid];
offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = offsets_k[tid];
break;
}
}
}
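// Worked example (hypothetical sizes, NVTE_HD_HD_HD branch): h = hg = 2,
// d_qk = d_v = 64, actual_b = 3, max_b = 4, and
// cu_seqlens_q_padded = cu_seqlens_kv_padded = [0, 4, 8, 13]. Threads
// tid = 0..4 run, with cu_seqlens_id = min(tid, 3):
//   offsets_q = offsets_k = offsets_v = offsets_o = 2 * 64 * [0, 4, 8, 13, 13]
//                                                 = [0, 512, 1024, 1664, 1664]
//   offsets_s = 2 * [0, 4, 8, 13, 13] = [0, 8, 16, 26, 26]
// Entries past actual_b repeat the last real offset, so the rows that pad the
// batch out to max_b become zero-length ragged segments.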
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
int64_t d_v, const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
DType offset_dtype, void *offsets_q, void *offsets_k,
void *offsets_v, void *offsets_o) {
void *offsets_v, void *offsets_o, void *offsets_s) {
if (offset_dtype == DType::kInt32) {
cu_seqlens_padded_to_offsets_impl<int32_t>(
layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
reinterpret_cast<int32_t *>(offsets_q), reinterpret_cast<int32_t *>(offsets_k),
reinterpret_cast<int32_t *>(offsets_v), reinterpret_cast<int32_t *>(offsets_o));
reinterpret_cast<int32_t *>(offsets_v), reinterpret_cast<int32_t *>(offsets_o),
reinterpret_cast<int32_t *>(offsets_s));
} else {
assert(offset_dtype == DType::kInt64 && "expect int64");
cu_seqlens_padded_to_offsets_impl<int64_t>(
layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
reinterpret_cast<int64_t *>(offsets_q), reinterpret_cast<int64_t *>(offsets_k),
reinterpret_cast<int64_t *>(offsets_v), reinterpret_cast<int64_t *>(offsets_o));
reinterpret_cast<int64_t *>(offsets_v), reinterpret_cast<int64_t *>(offsets_o),
reinterpret_cast<int64_t *>(offsets_s));
}
}
@@ -450,6 +460,40 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
return DType::kInt32;
}
// quantize batch size
size_t get_max_batch_size(size_t batch_size) {
size_t max_b = batch_size;
size_t log2_b = ceil(log2(batch_size));
// batch size is expected to be 10s-100s
// b = 1, ..., 32 -> max_b = 32
// b = 33, ..., 512 -> max_b = next power of 2
// otherwise -> max_b = b
if (log2_b <= 5) {
max_b = 32;
} else if (log2_b <= 9) {
max_b = pow(2, log2_b);
}
return max_b;
}
// quantize token count
size_t get_max_tokens(size_t num_tokens) {
// token count is expected to be 1k's-100k's
// t = 0, ..., 1024 -> max_t = 1024
// t = 1025, ..., 32k -> max_t = next power of 2
// t = 32k+1, ... -> max_t = round up to the next multiple of 32k
size_t log2_t = ceil(log2(num_tokens));
size_t max_t = 0;
if (log2_t <= 10) {
max_t = 1024;
} else if (log2_t <= 15) {
max_t = pow(2, log2_t);
} else {
max_t = (num_tokens + 32767) / 32768 * 32768;
}
return max_t;
}
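// Illustrative sanity checks for the bucketing above (kept as comments; not
// part of this commit):
//   assert(get_max_batch_size(8) == 32);       // b <= 32 snaps to 32
//   assert(get_max_batch_size(100) == 128);    // 33..512: next power of 2
//   assert(get_max_batch_size(1000) == 1000);  // > 512: left unquantized
//   assert(get_max_tokens(900) == 1024);       // t <= 1024 snaps to 1024
//   assert(get_max_tokens(20000) == 32768);    // 1025..32k: next power of 2
//   assert(get_max_tokens(70000) == 98304);    // > 32k: next multiple of 32k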
} // namespace fused_attn
// get cuDNN data type
......
@@ -122,21 +122,24 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *
int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
int32_t *o_ragged_offset);
__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens,
__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b,
int32_t const *const q_cu_seqlens,
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
int64_t d_v, const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
DType offset_dtype, void *offsets_q, void *offsets_k,
void *offsets_v, void *offsets_o);
void *offsets_v, void *offsets_o, void *offsets_s);
DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
int64_t head_dim_qk, int64_t head_dim_v);
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
@@ -277,7 +277,16 @@ class FusedAttnFwdPrimitive(BasePrimitive):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
# cuDNN 9.6 reduces the required softmax shape
if get_cudnn_version() >= (9, 6, 0):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
else:
softmax_shape = (
*batch_shape,
attn_heads,
q_max_seqlen,
config.max_segments_per_seq,
)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
......
@@ -7671,6 +7671,60 @@ class DotProductAttention(TransformerEngineBaseModule):
based on its internal logic. These optimizations trade memory for performance
and should be used with care.
.. note::
.. _cu_seqlens note:
When training data has variable sequence lengths, users have two options.
1. Manipulate the data and pad all sequences to the same length. Use
:attr:`qkv_format` = {"bshd", "sbhd"} and
:attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
(which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
the real sequence length information. For example, a batch of 3 sequences
[a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.
2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
:attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
without any padding, and the sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.
In certain use cases, a varying number of identifier tokens are inserted between
sequences. These tokens do not participate in the attention calculation.
:attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
in such cases to correctly identify the start and end of each sequence in a batch.
For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
:attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
for self-attention.
.. note::
.. _max_seqlen note:
When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
:attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
:attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
infer them as such.
When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
:attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
and pass them in.
- As the maximum sequence lengths, batch size, and number of tokens change from batch to
batch, tensor construction must support dynamic shapes. FlashAttention and
UnfusedDotProductAttention handle this naturally, while FusedAttention requires these
parameters to be static in order to build its graphs before the performance heuristics
analysis. To reduce the number of graphs created per run, Transformer Engine 1.13+
quantizes the relevant parameters: for cuDNN < 9.6, {batch size, :attr:`max_seqlen_q`,
:attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {the "t" dimension of :attr:`query_layer`,
the "t" dimension of :attr:`key_layer`}.
Parameters
----------
query_layer : torch.Tensor
@@ -7693,25 +7747,29 @@ class DotProductAttention(TransformerEngineBaseModule):
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for
`query_layer`, with shape [batch_size + 1] and dtype torch.int32.
When there is no padding between sequences in a batch,
`cu_seqlens_q_padded = cu_seqlens_q`.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
When there is no padding between sequences in a batch,
`cu_seqlens_kv_padded = cu_seqlens_kv`.
See :ref:`note<cu_seqlens note>` for more details.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
See :ref:`note<max_seqlen note>` for more details.
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
See :ref:`note<max_seqlen note>` for more details.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
'arbitrary'}, default = `None`. Type of attention mask passed into
@@ -7902,6 +7960,7 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
batch_size = len(cu_seqlens_q) - 1
if max_seqlen_q is None:
if cu_seqlens_q_padded is not None:
seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
@@ -7914,7 +7973,6 @@ class DotProductAttention(TransformerEngineBaseModule):
else:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
batch_size = len(cu_seqlens_q) - 1
cp_size = 1
if isinstance(self.cp_group, dist_group_type):
@@ -7929,10 +7987,12 @@
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
batch_size = query_layer.shape[1]
else:
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
batch_size = query_layer.shape[0]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
@@ -7941,13 +8001,13 @@
assert all(
seqlens_q <= max_seqlen_q
), """Sequence lengths indicated by cu_seqlens_q must be no greater than
the sequence dimention in 'query_layer'!"""
the sequence dimension in 'query_layer'!"""
if cu_seqlens_kv is not None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
assert all(
seqlens_kv <= max_seqlen_kv
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimention in 'key_layer' and 'value_layer'!"""
the sequence dimension in 'key_layer' and 'value_layer'!"""
if cu_seqlens_q is None or cu_seqlens_kv is None:
if "padding" in attn_mask_type:
assert (
......