Unverified Commit 70d3251f authored by Charlene Yang, committed by GitHub

[C/PyTorch] Simplify THD offset tensors (#927)



* simplify offset tensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; tests pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix C lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_offset with with_padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_padding with padded
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes after merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for fused attn fwd/bwd calls
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust spacing in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch tests; fix paddle api
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn_biases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix AttnFuncWithCP backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn with CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 94a426b0
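
Note: this commit replaces the four per-tensor ragged-offset arguments (seq_offsets_q/k/v/o) with padded cumulative-sequence-length tensors (cu_seqlens_padded, or cu_seqlens_q_padded/cu_seqlens_kv_padded); the per-tensor offsets are now derived on device. A minimal plain-Python sketch of how the two cumulative-length tensors relate, using the documented example `[a, PAD, b, b, c, PAD, PAD, d, d]` (the per-sequence lengths below are read off that example; cumsum0 is a hypothetical helper):

    # Per-sequence token counts for [a, PAD, b, b, c, PAD, PAD, d, d]
    seqlens = [1, 2, 1, 2]          # actual tokens per sequence
    seqlens_padded = [2, 2, 3, 2]   # tokens per sequence including trailing PADs

    def cumsum0(xs):
        # Prefix sum with a leading 0, giving [batch_size + 1] entries
        out = [0]
        for x in xs:
            out.append(out[-1] + x)
        return out

    cu_seqlens = cumsum0(seqlens)                # [0, 1, 3, 4, 6]
    cu_seqlens_padded = cumsum0(seqlens_padded)  # [0, 2, 4, 7, 9]
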
......@@ -832,24 +832,6 @@ def _run_dot_product_attention(
inp[i].requires_grad = True
inp_orig[i].requires_grad = True
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
if qkv_format == "thd":
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
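
The block deleted above is exactly the arithmetic that moves into the backend: each per-tensor element offset is the padded cumulative length scaled by that tensor's per-token element count. A hedged Python restatement of the deleted logic (variable names follow the test code):

    def thd_offsets(qkv_group, num_heads, num_gqa_groups, head_dim,
                    cu_seqlens_q_after_pad, cu_seqlens_kv_after_pad):
        # O always holds num_heads * head_dim elements per token
        seq_offsets_o = [num_heads * head_dim * t for t in cu_seqlens_q_after_pad]
        if qkv_group == "hd_hd_hd":              # separate Q, K, V tensors
            seq_offsets_q = [num_heads * head_dim * t for t in cu_seqlens_q_after_pad]
            seq_offsets_k = [num_gqa_groups * head_dim * t for t in cu_seqlens_kv_after_pad]
        elif qkv_group in ("3hd", "h3d"):        # packed QKV: 3 * h * d per token
            seq_offsets_q = [3 * num_heads * head_dim * t for t in cu_seqlens_q_after_pad]
            seq_offsets_k = seq_offsets_q
        elif qkv_group in ("hd_2hd", "hd_h2d"):  # packed KV: 2 * hg * d per token
            seq_offsets_q = [num_heads * head_dim * t for t in cu_seqlens_q_after_pad]
            seq_offsets_k = [2 * num_gqa_groups * head_dim * t for t in cu_seqlens_kv_after_pad]
        else:
            raise ValueError(f"unexpected qkv_group: {qkv_group}")
        return seq_offsets_q, seq_offsets_k, seq_offsets_k, seq_offsets_o  # V shares K's offsets

As the next hunk shows, the test now simply forwards cu_seqlens_q_after_pad / cu_seqlens_kv_after_pad when the backend is FusedAttention.
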
......@@ -928,10 +910,8 @@ def _run_dot_product_attention(
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
......@@ -1957,8 +1937,6 @@ class _custom_mha_fp8(torch.autograd.Function):
None,
None,
None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2038,8 +2016,6 @@ class _custom_mha_fp8(torch.autograd.Function):
FusedAttnBackend["FP8"],
None,
None,
None,
None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
......
......@@ -8,9 +8,11 @@ import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
_cudnn_version,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
......@@ -58,7 +60,7 @@ model_configs_fused_attn = {
}
@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
......
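
The only change in this file swaps the test-local _cudnn_version() helper for the shared get_cudnn_version() utility; both gates rely on Python's lexicographic tuple comparison. A small sketch of the gate (assuming, as the decorator above does, that get_cudnn_version() returns a comparable (major, minor, patch) tuple):

    from transformer_engine.pytorch.utils import get_cudnn_version

    def cudnn_at_least(required=(8, 9, 7)):
        # (8, 9, 7) >= (8, 9, 7) is True; tuples compare element-wise
        return get_cudnn_version() >= required
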
......@@ -196,21 +196,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
......@@ -252,8 +247,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -272,21 +266,19 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
}
// NVTE fused attention BWD with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
......@@ -338,8 +330,7 @@ void nvte_fused_attn_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -366,21 +357,18 @@ void nvte_fused_attn_bwd_qkvpacked(
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o,
const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
......@@ -426,8 +414,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -450,18 +438,16 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
......@@ -519,9 +505,9 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q,
input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace,
stream, handle);
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -549,9 +535,8 @@ void nvte_fused_attn_bwd_kvpacked(
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......@@ -560,10 +545,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
......@@ -601,8 +584,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -625,20 +608,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
......@@ -690,8 +670,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
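
Across all four public entry points above the pattern is the same: one or two cu_seqlens_*_padded tensors replace the four seq_offsets tensors and are forwarded untouched to the arbitrary-seqlen implementations. A hedged PyTorch sketch of how a caller might prepare the padded tensor (helper name hypothetical; per the header notes, bshd/sbhd callers may pass an empty tensor):

    import torch

    def make_cu_seqlens_padded(qkv_format, seqlens_padded, device="cuda"):
        if qkv_format in ("bshd", "sbhd"):
            # Not used by the attention calculation for these formats
            return torch.empty(0, dtype=torch.int32, device=device)
        lens = torch.tensor([0] + list(seqlens_padded), device=device)
        return torch.cumsum(lens, dim=0).to(torch.int32)
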
......@@ -53,9 +53,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK,
void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
......@@ -297,8 +297,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
return;
}
......@@ -330,17 +332,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
}
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
......@@ -354,9 +368,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK,
void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
......@@ -366,9 +380,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try {
FADescriptor_v1 descriptor{b,
......@@ -646,8 +657,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
return;
}
......@@ -692,10 +705,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
}
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
}
if (is_dropout) {
......@@ -715,8 +741,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -744,10 +769,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -801,9 +823,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -825,8 +846,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -866,10 +886,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
......@@ -881,9 +898,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ,
devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -905,9 +922,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle) {
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -936,10 +952,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -993,9 +1007,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1019,9 +1032,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle) {
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1060,10 +1072,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
......@@ -1076,8 +1086,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1094,15 +1104,17 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
}
}
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -1122,10 +1134,8 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -1179,9 +1189,8 @@ void fused_attn_arbitrary_seqlen_fwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1205,9 +1214,9 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
......@@ -1234,10 +1243,8 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
......@@ -1250,8 +1257,8 @@ void fused_attn_arbitrary_seqlen_bwd(
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
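
Because the Q/K/V/O offsets are now derived on device, both impl functions reserve four extra (b + 1)-element int32 buffers at the end of the workspace (seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t)). A plain-Python restatement of the carve-out arithmetic shown above:

    INT32 = 4  # sizeof(int32_t)

    def ragged_workspace_layout(plan_workspace_size, b):
        actual_seqlen_workspace_size = 2 * b * INT32
        dev_offsets_q = plan_workspace_size + actual_seqlen_workspace_size
        dev_offsets_k = dev_offsets_q + (b + 1) * INT32
        dev_offsets_v = dev_offsets_k + (b + 1) * INT32
        dev_offsets_o = dev_offsets_v + (b + 1) * INT32
        total = dev_offsets_o + (b + 1) * INT32  # matches the reported *workspace_size
        return dev_offsets_q, dev_offsets_k, dev_offsets_v, dev_offsets_o, total
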
......@@ -22,8 +22,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
......@@ -31,8 +30,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
......@@ -41,9 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......@@ -52,20 +49,21 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
......@@ -73,9 +71,9 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -360,6 +360,38 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
// convert cu_seqlens_padded to offsets
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) {
offsets_o[tid] = h * d * cu_seqlens_q_padded[tid];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
break;
}
}
}
} // namespace fused_attn
// get cuDNN data type
......
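
As a quick sanity check of the kernel above, these are the offsets it produces for the NVTE_3HD branch with assumed sizes h = 16, d = 64 and the docstring's cu_seqlens_padded = [0, 2, 4, 7, 9]:

    h, d = 16, 64
    cu_seqlens_padded = [0, 2, 4, 7, 9]
    offsets_q = [3 * h * d * t for t in cu_seqlens_padded]  # [0, 6144, 12288, 21504, 27648]
    offsets_o = [h * d * t for t in cu_seqlens_padded]      # [0, 2048, 4096, 7168, 9216]
    # offsets_k and offsets_v equal offsets_q for the 3HD/H3D layout group
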
......@@ -121,6 +121,11 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
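
The kernel declared here is launched (at both call sites in the .cu file above) with 128 threads per block and one thread per entry of the (b + 1)-element offset arrays; the grid arithmetic is a standard ceiling division:

    nthreads_per_block = 128

    def grid_size(b):
        # (b + 128) // 128 == ceil((b + 1) / 128): one thread per offsets entry
        return (b + nthreads_per_block) // nthreads_per_block
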
......@@ -166,21 +166,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences
* in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
 * the offset tensor is not used in the attention calculation and can be set to an empty `NVTETensor`.
 * When the QKV format is `thd`, this tensor should follow these rules.
 * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
......@@ -189,10 +183,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
......@@ -207,11 +198,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
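
To make the notes above concrete, a hedged PyTorch sketch of inputs for the thd + 3HD case (shapes and dtype are assumptions for illustration, matching the documented example with 9 padded tokens):

    import torch

    h, d = 16, 64
    qkv = torch.empty(9, 3, h, d, dtype=torch.bfloat16, device="cuda")  # packed 3HD, thd format
    cu_seqlens = torch.tensor([0, 1, 3, 4, 6], dtype=torch.int32, device="cuda")
    cu_seqlens_padded = torch.tensor([0, 2, 4, 7, 9], dtype=torch.int32, device="cuda")
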
......@@ -227,21 +216,15 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences
* in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
 * the offset tensor is not used in the attention calculation and can be set to an empty `NVTETensor`.
 * When the QKV format is `thd`, this tensor should follow these rules.
 * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] O The O tensor from forward.
......@@ -253,10 +236,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
......@@ -267,13 +247,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
......@@ -292,57 +273,49 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
 * When there is no padding between sequences, the offset tensors should be equal to
 * `cu_seqlens_q` and `cu_seqlens_kv`, respectively.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
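 *
 * A minimal sketch of one way to derive the two tensors for that example
 * (`seqlens`, `padded_seqlens` and `cumsum` are illustrative, not part of the API):
 \verbatim
 seqlens           = [1, 2, 1, 2]                   # actual lengths of a, b, c, d
 padded_seqlens    = [2, 2, 3, 2]                   # lengths including trailing padding
 cu_seqlens        = cumsum([0] + seqlens)          # [0, 1, 3, 4, 6]
 cu_seqlens_padded = cumsum([0] + padded_seqlens)   # [0, 2, 4, 7, 9]
 \endverbatim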
*
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used in the computation for Q;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
 * \param[in] max_seqlen_kv Max sequence length used in the computation for KV;
 * it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o,
const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
......@@ -357,59 +330,52 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors should follow the rules below.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is no padding between sequences, the offset tensors should be equal to
* `cu_seqlens_q` and `cu_seqlens_kv` respectively.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
 * \param[in] max_seqlen_q Max sequence length used in the computation for Q;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
 * \param[in] max_seqlen_kv Max sequence length used in the computation for KV;
 * it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -431,66 +397,48 @@ void nvte_fused_attn_bwd_kvpacked(
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors should follow the rules below.
* When there is no padding between sequences, the offset tensors are,
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is no padding between sequences, the offset tensors should be equal to
* `cu_seqlens_q` and `cu_seqlens_kv` respectively.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
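 *
 * For migration reference, the per-layout offset tensors that this API replaces
 * can be recovered from the padded offsets (a sketch for the `hd_hd_hd` layout
 * group; other groups scale analogously):
 \verbatim
 seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q_padded
 seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv_padded
 seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv_padded
 \endverbatim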
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used in the computation for Q;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
 * \param[in] max_seqlen_kv Max sequence length used in the computation for K and V;
 * it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......@@ -510,73 +458,56 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
*
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors should follow the rules below.
* When there is no padding between sequences, the offset tensors are,
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is no padding between sequences, the offset tensors should be equal to
* `cu_seqlens_q` and `cu_seqlens_kv` respectively.
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
 * \param[in] max_seqlen_q Max sequence length used in the computation for Q;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
 * \param[in] max_seqlen_kv Max sequence length used in the computation for K and V;
 * it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -143,19 +143,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
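// NVTE_BS3HD is a `bshd`-format layout, so the ragged (padded) offset argument is
// unused in the attention calculation and a dummy tensor suffices (see header notes)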
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
......@@ -164,7 +162,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
......@@ -208,15 +205,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -230,7 +225,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
......@@ -251,7 +245,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
......@@ -340,10 +333,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -355,7 +346,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
......@@ -373,7 +363,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
......@@ -426,7 +415,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
......@@ -507,15 +495,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[10];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -533,7 +519,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
......@@ -559,7 +544,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
......
......@@ -640,12 +640,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
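// the padded-offset argument is unused for `bshd`/`sbhd` QKV formats,
// so a dummy tensor is passed to the calls below (see header notes)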
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -655,12 +654,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -818,12 +814,12 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -833,12 +829,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -916,7 +912,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -929,7 +924,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -1001,10 +995,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1018,10 +1011,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1100,10 +1092,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1113,10 +1104,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -550,10 +550,8 @@ class AttnFuncWithCP(torch.autograd.Function):
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
......@@ -694,10 +692,8 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
)
if len(rest) > 0:
......@@ -769,14 +765,12 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q,
seq_offsets_k=(
None if seq_offsets_k is None else seq_offsets_k // 2
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=(
None
if cu_seqlens_kv_padded is None
else cu_seqlens_kv_padded // 2
),
seq_offsets_v=(
None if seq_offsets_v is None else seq_offsets_v // 2
),
seq_offsets_o=seq_offsets_o,
)
)
if len(rest) > 0:
......@@ -863,14 +857,12 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=(
None if seq_offsets_q is None else seq_offsets_q // 2
),
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=(
None if seq_offsets_o is None else seq_offsets_o // 2
cu_seqlens_q_padded=(
None
if cu_seqlens_q_padded is None
else cu_seqlens_q_padded // 2
),
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
)
if len(rest) > 0:
......@@ -940,10 +932,8 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
)
if len(rest) > 0:
......@@ -1082,10 +1072,8 @@ class AttnFuncWithCP(torch.autograd.Function):
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*rng_states,
*attn_biases,
)
......@@ -1106,10 +1094,10 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
(seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10]
(cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8]
cp_size = get_distributed_world_size(ctx.cp_group)
rng_states = ctx.saved_tensors[10 : 10 + cp_size]
attn_biases = ctx.saved_tensors[10 + cp_size : 10 + cp_size * 2]
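# the forward pass saved 6 base tensors plus the 2 padded-offset tensors,
# so the per-step rng_states and attn_biases now start at index 8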
rng_states = ctx.saved_tensors[8 : 8 + cp_size]
attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
......@@ -1224,10 +1212,8 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
......@@ -1305,10 +1291,10 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q,
None if seq_offsets_k is None else seq_offsets_k // 2,
None if seq_offsets_v is None else seq_offsets_v // 2,
seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
),
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
......@@ -1392,10 +1378,10 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
None if seq_offsets_q is None else seq_offsets_q // 2,
seq_offsets_k,
seq_offsets_v,
None if seq_offsets_o is None else seq_offsets_o // 2,
cu_seqlens_q_padded=(
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
),
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
......@@ -1461,10 +1447,8 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
......@@ -1658,8 +1642,6 @@ class AttnFuncWithCP(torch.autograd.Function):
None,
None,
None,
None,
None,
attn_dbias,
None,
None,
......@@ -1675,10 +1657,8 @@ def attn_forward_func_with_cp(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
......@@ -1721,10 +1701,8 @@ def attn_forward_func_with_cp(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
......@@ -2593,8 +2571,6 @@ class FlashAttention(torch.nn.Module):
max_seqlen_kv,
None,
None,
None,
None,
self.attention_dropout if self.training else 0.0,
cp_group,
cp_global_ranks,
......@@ -2690,10 +2666,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
is_training,
max_seqlen,
cu_seqlens,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
qkv,
qkv_dtype,
attn_bias,
......@@ -2738,10 +2711,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_dtype_forward,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2806,10 +2776,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_dtype,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
None,
None,
None,
......@@ -2830,14 +2797,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(
*qkvo_tensors,
cu_seqlens,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
*fp8_tensors,
*aux_ctx_tensors,
*qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
)
ctx.fp8_meta = fp8_meta
ctx.max_seqlen = max_seqlen
......@@ -2870,10 +2830,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv,
out,
cu_seqlens,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
qkv_fp8,
out_fp8,
fwd_scales,
......@@ -2939,10 +2896,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_dtype_backward,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -2994,10 +2948,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
None,
None,
None,
......@@ -3019,9 +2970,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (
None,
None,
None,
None,
None,
None,
......@@ -3045,9 +2993,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
)
# else, return (dqkv, dbias)
return (
None,
None,
None,
None,
None,
None,
......@@ -3082,10 +3027,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
q,
kv,
qkv_dtype,
......@@ -3139,10 +3082,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_forward,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -3214,10 +3155,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_dtype,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
None,
None,
None,
......@@ -3241,10 +3180,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
*qkvo_tensors,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*fp8_tensors,
*aux_ctx_tensors,
)
......@@ -3282,10 +3219,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
q_fp8,
kv_fp8,
out_fp8,
......@@ -3355,10 +3290,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_backward,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -3428,10 +3361,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
None,
None,
None,
......@@ -3460,8 +3391,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dkv,
None,
......@@ -3489,8 +3418,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dkv,
None,
......@@ -3522,10 +3449,8 @@ class FusedAttnFunc(torch.autograd.Function):
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
q,
k,
v,
......@@ -3602,10 +3527,8 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_forward,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -3727,10 +3650,8 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_dtype,
fused_attention_backend,
attn_bias,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
None,
None,
None,
......@@ -3763,10 +3684,8 @@ class FusedAttnFunc(torch.autograd.Function):
*qkvo_tensors,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*fp8_tensors,
*aux_ctx_tensors,
)
......@@ -3805,10 +3724,8 @@ class FusedAttnFunc(torch.autograd.Function):
out,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
q_fp8,
k_fp8,
v_fp8,
......@@ -3882,10 +3799,8 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_backward,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -3903,6 +3818,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_bias_type,
ctx.attn_mask_type,
)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(
data=dq_fp8,
......@@ -4007,10 +3923,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
None,
None,
None,
......@@ -4039,8 +3953,6 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dk,
dv,
......@@ -4069,8 +3981,6 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dk,
dv,
......@@ -4186,10 +4096,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal",
......@@ -4271,31 +4179,9 @@ class FusedAttention(torch.nn.Module):
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if (
seq_offsets_q is None
or seq_offsets_k is None
or seq_offsets_v is None
or seq_offsets_o is None
or context_parallel
):
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
qkv_group = "hd_hd_hd" if context_parallel else qkv_group
num_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
head_dim = query_layer.shape[-1]
seq_offsets_o = num_heads * head_dim * cu_seqlens_q
if qkv_group == "hd_hd_hd":
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
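# when padded cumulative sequence lengths are not provided, assume
# there is no padding between sequences in the batch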
if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
cu_seqlens_q_padded = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype]
......@@ -4325,10 +4211,8 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
self.attention_dropout if self.training else 0.0,
cp_group,
cp_global_ranks,
......@@ -4356,10 +4240,8 @@ class FusedAttention(torch.nn.Module):
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
query_layer,
key_layer,
value_layer,
......@@ -4669,10 +4551,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None,
......@@ -4749,23 +4629,21 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `query_layer`,
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32.
seq_offsets_q: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_k: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `key_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_v: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `value_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_o: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for forward output,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for
`query_layer`, with shape [batch_size + 1] and dtype torch.int32.
When there is no padding between sequences in a batch,
`cu_seqlens_q_padded = cu_seqlens_q`.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
When there is no padding between sequences in a batch,
`cu_seqlens_kv_padded = cu_seqlens_kv`.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
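To make the two tensor families concrete, here is a small, hypothetical `thd` batch of three sequences with token counts [5, 3, 7], each padded to a multiple of 4 (all values illustrative, not from this diff):

import torch

seqlens        = torch.tensor([5, 3, 7], dtype=torch.int32)  # actual token counts
seqlens_padded = torch.tensor([8, 4, 8], dtype=torch.int32)  # padded to multiples of 4

zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens_q = torch.cat([zero, torch.cumsum(seqlens, 0, dtype=torch.int32)])
# tensor([ 0,  5,  8, 15], dtype=torch.int32)
cu_seqlens_q_padded = torch.cat([zero, torch.cumsum(seqlens_padded, 0, dtype=torch.int32)])
# tensor([ 0,  8, 12, 20], dtype=torch.int32)

# When sequences are stored back-to-back with no padding between them,
# cu_seqlens_q_padded equals cu_seqlens_q, which is also the default used
# when cu_seqlens_q_padded is not supplied.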
......@@ -4992,9 +4870,25 @@ class DotProductAttention(TransformerEngineBaseModule):
# certain asserts before executing the forward pass.
# Filter: QKV layout.
if use_unfused_attention and qkv_format == "thd":
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if qkv_format == "thd":
if use_unfused_attention:
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
            if use_flash_attention and (
                (
                    cu_seqlens_q_padded is not None
                    and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
                )
                or (
                    cu_seqlens_kv_padded is not None
                    and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
                )
            ):
self.logger.debug(
"Disabling FlashAttention for qkv_format = thd "
"when there is padding between sequences."
)
use_flash_attention = False
# Filter: ONNX export.
if is_in_onnx_export_mode():
......@@ -5354,10 +5248,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
......@@ -5379,10 +5271,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
......
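Between the two dispatch sites above, the user-visible change is just the forward kwargs. A hypothetical end-to-end call with the new argument names (the constructor arguments, shapes, and dtypes are illustrative, not taken from this diff, and a CUDA device is required):

import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(
    num_attention_heads=16, kv_channels=64,
    qkv_format="thd", attn_mask_type="padding_causal",
)

cu        = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")
cu_padded = torch.tensor([0, 8, 12, 20], dtype=torch.int32, device="cuda")

# thd layout: [total_tokens, heads, head_dim], padding tokens included
q = torch.randn(20, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out = attn(
    q, k, v,
    cu_seqlens_q=cu, cu_seqlens_kv=cu,
    cu_seqlens_q_padded=cu_padded, cu_seqlens_kv_padded=cu_padded,
    max_seqlen_q=8, max_seqlen_kv=8,
)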
......@@ -85,10 +85,7 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -124,14 +121,8 @@ def fused_attn_fwd_qkvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for QKV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -246,10 +237,7 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens,
qkv,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -275,10 +263,7 @@ def fused_attn_bwd_qkvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -322,14 +307,8 @@ def fused_attn_bwd_qkvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for QKV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -419,10 +398,7 @@ def fused_attn_bwd_qkvpacked(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......@@ -449,10 +425,8 @@ def fused_attn_fwd_kvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -495,14 +469,10 @@ def fused_attn_fwd_kvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -621,10 +591,8 @@ def fused_attn_fwd_kvpacked(
q,
kv,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -653,10 +621,8 @@ def fused_attn_bwd_kvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -707,14 +673,10 @@ def fused_attn_bwd_kvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -811,10 +773,8 @@ def fused_attn_bwd_kvpacked(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......@@ -842,10 +802,8 @@ def fused_attn_fwd(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -892,14 +850,10 @@ def fused_attn_fwd(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -1021,10 +975,8 @@ def fused_attn_fwd(
k,
v,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -1054,10 +1006,8 @@ def fused_attn_bwd(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -1111,14 +1061,10 @@ def fused_attn_bwd(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -1220,10 +1166,8 @@ def fused_attn_bwd(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......
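The packed-QKV wrappers above keep a single `cu_seqlens_padded` rather than a q/kv pair because Q, K, and V share one storage tensor, so one set of token offsets positions all three. A hypothetical `t3hd` illustration (shapes illustrative only):

import torch

total_tokens, h, d = 20, 16, 64
qkv = torch.randn(total_tokens, 3, h, d)  # "t3hd" packed layout

cu_seqlens_padded = torch.tensor([0, 8, 12, 20], dtype=torch.int32)

# the same token offsets slice out Q, K, and V for every sequence
for i in range(cu_seqlens_padded.numel() - 1):
    start, end = int(cu_seqlens_padded[i]), int(cu_seqlens_padded[i + 1])
    q_i = qkv[start:end, 0]  # queries of sequence i, trailing padding included
    k_i = qkv[start:end, 1]
    v_i = qkv[start:end, 2]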
......@@ -26,22 +26,19 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
......@@ -53,8 +50,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
......@@ -67,27 +64,27 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......@@ -95,14 +92,14 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
......
......@@ -83,13 +83,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec();
......@@ -107,8 +105,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto O = torch::empty(o_shape, options);
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens, te_cu_seqlens_padded;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -150,27 +147,12 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
if (cu_seqlens_padded.has_value()) {
auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(),
cu_seqlens_padded_sizes.end()};
te_cu_seqlens_padded =
makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape,
DType::kInt32, nullptr, nullptr, nullptr);
}
......@@ -191,12 +173,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -232,12 +213,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
}
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -252,9 +232,8 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
......@@ -358,28 +337,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(
cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
TensorWrapper te_cu_seqlens_padded;
if (cu_seqlens_padded.has_value()) {
auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(),
cu_seqlens_padded_sizes.end()};
te_cu_seqlens_padded =
makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape,
DType::kInt32, nullptr, nullptr, nullptr);
}
......@@ -387,12 +351,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -400,12 +363,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -419,8 +381,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
......@@ -440,7 +402,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -489,28 +451,19 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
cu_seqlens_q_padded_sizes.end()};
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
cu_seqlens_kv_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
}
// extract rng seed and offset
......@@ -532,10 +485,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -573,10 +526,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -592,14 +545,14 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV) {
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
......@@ -689,29 +642,20 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
cu_seqlens_q_padded_sizes.end()};
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
cu_seqlens_kv_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors
......@@ -746,13 +690,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -760,13 +703,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -780,13 +722,13 @@ std::vector<at::Tensor> fused_attn_fwd(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
......@@ -802,7 +744,7 @@ std::vector<at::Tensor> fused_attn_fwd(
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -853,28 +795,19 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
cu_seqlens_q_padded_sizes.end()};
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
cu_seqlens_kv_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
}
// extract rng seed and offset
......@@ -897,11 +830,10 @@ std::vector<at::Tensor> fused_attn_fwd(
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout,
bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -939,11 +871,10 @@ std::vector<at::Tensor> fused_attn_fwd(
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout,
bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -959,14 +890,14 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV) {
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
......@@ -1131,29 +1062,20 @@ std::vector<at::Tensor> fused_attn_bwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) &&
(seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
cu_seqlens_q_padded_sizes.end()};
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
cu_seqlens_kv_padded_shape, DType::kInt32,
nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors
......@@ -1191,10 +1113,9 @@ std::vector<at::Tensor> fused_attn_bwd(
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(),
te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -1205,10 +1126,9 @@ std::vector<at::Tensor> fused_attn_bwd(
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(),
te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, workspace.data(),
at::cuda::getCurrentCUDAStream());
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......