Unverified commit 70d3251f authored by Charlene Yang, committed by GitHub

[C/PyTorch] Simplify THD offset tensors (#927)
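
This PR replaces the four ragged offset tensors (seq_offsets_q/k/v/o) in the fused-attention C API with padded cumulative sequence length tensors: cu_seqlens_q_padded and cu_seqlens_kv_padded, or a single cu_seqlens_padded for the QKV-packed variants. The per-tensor element offsets are now derived on-device by a new cu_seqlens_padded_to_offsets kernel from the QKV layout group, so the PyTorch, JAX, and Paddle frontends no longer compute them by hand.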



* simplify offset tensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; tests pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix C lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_offset with with_padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_padding with padded
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes after merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for fused attn fwd/bwd calls
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust spacing in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch tests; fix paddle api
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn_biases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix AttnFuncWithCP backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn with CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 94a426b0
@@ -832,24 +832,6 @@ def _run_dot_product_attention(
         inp[i].requires_grad = True
         inp_orig[i].requires_grad = True

-    # Create ragged offsets for q/k/v
-    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
-    qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
-    if qkv_format == "thd":
-        seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
-        if qkv_group == "hd_hd_hd":
-            seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
-            seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
-            seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
-        if qkv_group in ["3hd", "h3d"]:
-            seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
-            seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
-            seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
-        if qkv_group in ["hd_2hd", "hd_h2d"]:
-            seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
-            seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
-            seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
     # Create output gradient
     qkv_format_kv = "_".join(qkv_format)
     qkv_format_kv = qkv_format_kv.replace("s", "sq")
@@ -928,10 +910,8 @@ def _run_dot_product_attention(
         max_seqlen_kv=config.max_seqlen_kv,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
-        seq_offsets_q=seq_offsets_q,
-        seq_offsets_k=seq_offsets_k,
-        seq_offsets_v=seq_offsets_v,
-        seq_offsets_o=seq_offsets_o,
+        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
+        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
         attn_mask_type=config.attn_mask_type,
         checkpoint_core_attention=ckpt_attn,
         core_attention_bias_type=config.attn_bias_type,
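
The hunks above drop the hand-computed ragged offsets: the test now passes cu_seqlens_{q,kv}_after_pad straight through, and the per-layout-group multiplier moves into the library (see the cu_seqlens_padded_to_offsets kernel later in this diff). A self-contained C++ sketch of that arithmetic with illustrative values; the Group enum and main() driver are hypothetical stand-ins, not TE API:

// Standalone sketch of the THD offset arithmetic with example numbers.
#include <cstdint>
#include <cstdio>
#include <vector>

enum class Group { HD_HD_HD, G3HD_H3D, HD_2HD_H2D };  // stand-in for NVTE_QKV_Layout_Group

int main() {
  std::vector<int32_t> cu_q_padded = {0, 128, 384};   // b + 1 entries, b = 2
  std::vector<int32_t> cu_kv_padded = {0, 128, 384};
  int32_t h = 16, hg = 4, d = 64;       // attn heads, GQA groups, head dim
  Group g = Group::G3HD_H3D;            // e.g. a t3hd / th3d packed layout
  for (size_t i = 0; i < cu_q_padded.size(); ++i) {
    int32_t off_o = h * d * cu_q_padded[i];  // O always strides by h * d per token
    int32_t off_q = 0, off_k = 0, off_v = 0;
    switch (g) {
      case Group::HD_HD_HD:      // separate Q, K, V buffers
        off_q = h * d * cu_q_padded[i];
        off_k = hg * d * cu_kv_padded[i];
        off_v = off_k;
        break;
      case Group::G3HD_H3D:      // Q, K, V packed into one buffer, hence factor 3
        off_q = 3 * h * d * cu_q_padded[i];
        off_k = off_q;
        off_v = off_q;
        break;
      case Group::HD_2HD_H2D:    // separate Q, packed KV, hence factor 2
        off_q = h * d * cu_q_padded[i];
        off_k = 2 * hg * d * cu_kv_padded[i];
        off_v = off_k;
        break;
    }
    std::printf("entry %zu: q=%d k=%d v=%d o=%d\n", i, off_q, off_k, off_v, off_o);
  }
  return 0;
}

For the packed layout above, off_q at entry 1 is 3 * 16 * 64 * 128 = 393216 elements, which is exactly what the deleted Python computed as config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad.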
@@ -1957,8 +1937,6 @@ class _custom_mha_fp8(torch.autograd.Function):
             None,
             None,
             None,
-            None,
-            None,
             fp8_meta["scaling_fwd"].scale_inv[META_QKV],
             fp8_meta["scaling_fwd"].scale_inv[META_S],
             fp8_meta["scaling_fwd"].scale[META_S],
@@ -2038,8 +2016,6 @@ class _custom_mha_fp8(torch.autograd.Function):
             FusedAttnBackend["FP8"],
             None,
             None,
-            None,
-            None,
             fwd_scale_inverses[META_QKV],  # d_scale_qkv,
             fwd_scale_inverses[META_S],  # d_scale_s,
             fwd_scale_inverses[META_O],  # d_scale_o,
......
@@ -8,9 +8,11 @@ import subprocess
 from test_fused_attn import (
     ModelConfig,
     _is_flash_attention_2_available,
-    _cudnn_version,
 )
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import (
+    get_device_compute_capability,
+    get_cudnn_version,
+)

 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
@@ -58,7 +60,7 @@ model_configs_fused_attn = {
 }

-@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
......
@@ -196,21 +196,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 // NVTE fused attention FWD with packed QKV
 void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
                                    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                                   const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q,
-                                   const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
-                                   const NVTETensor seq_offsets_o, const NVTETensor rng_state,
-                                   size_t max_seqlen, bool is_training, float attn_scale,
-                                   float dropout, NVTE_QKV_Layout qkv_layout,
+                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
+                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                                    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
   const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
   const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
@@ -252,8 +247,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
     fused_attn_arbitrary_seqlen_fwd_qkvpacked(
         b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
-        input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
-        input_rng_state, wkspace, stream, handle);
+        input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -272,21 +266,19 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   }
 }

 // NVTE fused attention BWD with packed QKV
-void nvte_fused_attn_bwd_qkvpacked(
-    const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
-    NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
-    const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
-    const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen,
-    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
+                                   const NVTETensor S, NVTETensor dP,
+                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
+                                   NVTETensor dBias, const NVTETensor cu_seqlens,
+                                   const NVTETensor cu_seqlens_padded, size_t max_seqlen,
+                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                   NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
   const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
   const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
   const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
@@ -338,8 +330,7 @@ void nvte_fused_attn_bwd_qkvpacked(
     fused_attn_arbitrary_seqlen_bwd_qkvpacked(
         b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
         input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
-        input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
-        input_rng_state, wkspace, stream, handle);
+        input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
@@ -366,21 +357,18 @@ void nvte_fused_attn_bwd_qkvpacked(
 void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
                                   NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                                   const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-                                  const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
-                                  const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o,
-                                  const NVTETensor rng_state, size_t max_seqlen_q,
-                                  size_t max_seqlen_kv, bool is_training, float attn_scale,
-                                  float dropout, NVTE_QKV_Layout qkv_layout,
+                                  const NVTETensor cu_seqlens_q_padded,
+                                  const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
+                                  size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
+                                  float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                                   NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
@@ -426,8 +414,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
     fused_attn_arbitrary_seqlen_fwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
         bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
-        input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
-        input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
+        input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -450,18 +438,16 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
     const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
     NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
-    const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
-    float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
+    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
   const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
@@ -519,9 +505,9 @@ void nvte_fused_attn_bwd_kvpacked(
     fused_attn_arbitrary_seqlen_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
-        output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q,
-        input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace,
-        stream, handle);
+        output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
+        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
+        handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -549,9 +535,8 @@ void nvte_fused_attn_bwd_kvpacked(
 void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor Bias, NVTETensor S, NVTETensor O,
                          NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
-                         const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
-                         const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
-                         const NVTETensor seq_offsets_o, const NVTETensor rng_state,
+                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
@@ -560,10 +545,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
@@ -601,8 +584,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     fused_attn_arbitrary_seqlen_fwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
         bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
-        input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
-        input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
+        input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -625,20 +608,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
                          const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
                          NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
-                         const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
-                         const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
-                         const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv,
-                         float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                         NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                         NVTETensor workspace, cudaStream_t stream) {
+                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
+                         size_t max_seqlen_kv, float attn_scale, float dropout,
+                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                         NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
-  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
-  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
-  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
   const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
@@ -690,8 +670,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
         output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
-        input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
-        input_rng_state, wkspace, stream, handle);
+        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
+        handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
@@ -22,8 +22,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
-    const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
@@ -31,8 +30,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
-    const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
-    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
+    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_fwd_kvpacked(
@@ -41,9 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
-    const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
-    cudnnHandle_t handle);
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
@@ -52,20 +49,21 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
     Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
-    const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
-    cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_fwd(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
-    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
-    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+                                     size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
+                                     bool is_training, float attn_scale, float p_dropout,
+                                     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                     NVTE_Mask_Type mask_type, const Tensor *input_Q,
+                                     const Tensor *input_K, const Tensor *input_V,
+                                     const Tensor *input_Bias, Tensor *output_O,
+                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
+                                     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+                                     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
@@ -73,9 +71,9 @@ void fused_attn_arbitrary_seqlen_bwd(
     const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
     const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
     Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
-    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);

 #endif  // CUDNN_VERSION >= 8900
 }  // namespace transformer_engine
......
@@ -360,6 +360,38 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
     kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
   }
 }

+// convert cu_seqlens_padded to offsets
+__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
+                                             size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
+                                             int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
+                                             int32_t *offsets_k, int32_t *offsets_v,
+                                             int32_t *offsets_o) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < b + 1) {
+    offsets_o[tid] = h * d * cu_seqlens_q_padded[tid];
+    switch (layout_group) {
+      case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
+        offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
+        offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid];
+        offsets_v[tid] = offsets_k[tid];
+        break;
+      case NVTE_QKV_Layout_Group::NVTE_3HD:
+      case NVTE_QKV_Layout_Group::NVTE_H3D:
+        offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid];
+        offsets_k[tid] = offsets_q[tid];
+        offsets_v[tid] = offsets_q[tid];
+        break;
+      case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
+      case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
+        offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
+        offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid];
+        offsets_v[tid] = offsets_k[tid];
+        break;
+    }
+  }
+}
+
 }  // namespace fused_attn

 // get cuDNN data type
......
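
A plausible host-side launch for the new kernel, one thread per entry of the (b + 1)-element offset arrays; this wrapper and the 128-thread block size are assumptions (and it is assumed to sit inside the transformer_engine namespace alongside the kernel), not code from this PR:

// Illustrative launch wrapper; all device pointers are int32 arrays of
// length b + 1, allocated by the caller.
void launch_cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
                                         size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
                                         int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
                                         int32_t *offsets_k, int32_t *offsets_v,
                                         int32_t *offsets_o, cudaStream_t stream) {
  constexpr size_t kBlock = 128;                 // assumed block size
  size_t grid = (b + 1 + kBlock - 1) / kBlock;   // enough blocks to cover b + 1 threads
  fused_attn::cu_seqlens_padded_to_offsets<<<grid, kBlock, 0, stream>>>(
      layout_group, b, h, hg, d, cu_seqlens_q_padded, cu_seqlens_kv_padded, offsets_q, offsets_k,
      offsets_v, offsets_o);
}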
@@ -121,6 +121,11 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
                                              int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                              int32_t *kv_seqlens);

+__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
+                                             size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
+                                             int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
+                                             int32_t *offsets_k, int32_t *offsets_v,
+                                             int32_t *offsets_o);
+
 }  // namespace fused_attn

 cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
@@ -143,19 +143,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   TensorWrapper query_workspace_tensor;

   if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
     assert(q_max_seqlen == kv_max_seqlen);
-    nvte_fused_attn_fwd_qkvpacked(
-        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
+                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+                                  dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
+                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
+                                  qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
+                                  nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
         nullptr);
@@ -164,7 +162,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
         scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         query_workspace_tensor.data(), nullptr);
@@ -208,15 +205,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
     auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
     auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
-    nvte_fused_attn_bwd_qkvpacked(
-        qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-        s_tensor.data(),  // not used for F16
-        s_tensor.data(),  // not used for F16
-        &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+                                  s_tensor.data(),  // not used for F16
+                                  s_tensor.data(),  // not used for F16
+                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
+                                  q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+                                  q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
+                                  bias_type, mask_type, query_workspace_tensor.data(), nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
     auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
@@ -230,7 +225,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         s_tensor.data(),  // not used for F16
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
         nullptr);
@@ -251,7 +245,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
         dbias_tensor.data(), q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
         scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         query_workspace_tensor.data(), nullptr);
@@ -340,10 +333,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     nvte_fused_attn_fwd_qkvpacked(
         qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
-        descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, workspace_tensor.data(), stream);
+        rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -355,7 +346,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
         scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         workspace_tensor.data(), stream);
@@ -373,7 +363,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
         descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
         bias_type, mask_type, workspace_tensor.data(), stream);
@@ -426,7 +415,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
         dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
         bias_type, mask_type, query_workspace_tensor.data(), nullptr);
@@ -507,15 +495,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
     auto dqkv = buffers[10];
     auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
-    nvte_fused_attn_bwd_qkvpacked(
-        qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-        s_tensor.data(),  // not used for F16
-        s_tensor.data(),  // not used for F16
-        &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        workspace_tensor.data(), stream);
+    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+                                  s_tensor.data(),  // not used for F16
+                                  s_tensor.data(),  // not used for F16
+                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
+                                  q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+                                  q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
+                                  bias_type, mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -533,7 +519,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
         s_tensor.data(),  // not used for F16
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
@@ -559,7 +544,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
         &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
         dbias_tensor.data(), q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
         scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         workspace_tensor.data(), stream);
......
@@ -640,12 +640,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
   auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);

   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_fwd_qkvpacked(
-      te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
-      te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-      dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
-      is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-      workspace.data(), QKV.stream());
+  nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
+                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
+                                dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
+                                is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
+                                attn_mask_type_enum, workspace.data(), QKV.stream());

   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
@@ -655,12 +654,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
   output_s->data.dptr = GetOptionalDataPtr(softmax_aux);

   // execute the kernel
-  nvte_fused_attn_fwd_qkvpacked(
-      te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
-      te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-      dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
-      is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-      workspace.data(), QKV.stream());
+  nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
+                                &nvte_aux_tensor_pack, te_cu_seqlens.data(),
+                                dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
+                                is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
+                                attn_mask_type_enum, workspace.data(), QKV.stream());

   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), QKV.stream());
QKV.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), QKV.stream());
QKV.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
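Note that these Paddle bindings pass the same `dummy_seq_offsets` wrapper (a null-pointer tensor of shape [b + 1]) for every padded-offset argument, which suggests the layouts exposed here never consume real padded offsets. A rough Python analogue of that placeholder pattern, with the helper name made up for illustration:

    import torch

    def dummy_cu_seqlens(batch_size: int) -> torch.Tensor:
        # Shape [batch_size + 1] to match a real cumulative-length tensor; the
        # contents are never read on this path, so zeros are enough (the C++
        # side goes one step further and wraps a null pointer).
        return torch.zeros(batch_size + 1, dtype=torch.int32)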
@@ -818,12 +814,12 @@ void te_fused_attn_fwd_kvpacked(
  auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
  // populate tensors with appropriate shapes and dtypes
- nvte_fused_attn_fwd_kvpacked(
-     te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
-     te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
-     dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-     te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
-     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
+ nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
+                              &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
+                              te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
+                              dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
+                              max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
+                              bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
  // allocate memory for workspace and auxiliary output tensors
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
@@ -833,12 +829,12 @@ void te_fused_attn_fwd_kvpacked(
  output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
  // execute the kernel
- nvte_fused_attn_fwd_kvpacked(
-     te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
-     te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
-     dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-     te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
-     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
+ nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
+                              &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
+                              te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
+                              dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
+                              max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
+                              bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -916,7 +912,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
  nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
                               te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
                               te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                              dummy_seq_offsets.data(), dummy_seq_offsets.data(),
                               dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
                               max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
                               bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
@@ -929,7 +924,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
  nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
                               te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
                               te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                              dummy_seq_offsets.data(), dummy_seq_offsets.data(),
                               dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
                               max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
                               bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
@@ -1001,10 +995,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
-                     max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
-                     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
-                     Q.stream());
+                     te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
+                     p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
+                     workspace.data(), Q.stream());
  // allocate memory for workspace and auxiliary output tensors
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
@@ -1018,10 +1011,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
-                     max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
-                     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
-                     Q.stream());
+                     te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
+                     p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
+                     workspace.data(), Q.stream());
  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -1100,10 +1092,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
  nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
                      te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
                      te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-                     dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
-                     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
-                     Q.stream());
+                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
+                     max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
+                     attn_mask_type_enum, workspace.data(), Q.stream());
  // allocate memory for workspace
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
@@ -1113,10 +1104,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
  nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
                      te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
                      te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
-                     dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
-                     qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
-                     Q.stream());
+                     dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
+                     max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
+                     attn_mask_type_enum, workspace.data(), Q.stream());
  // destroy tensor wrappers
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...
@@ -85,10 +85,7 @@ def fused_attn_fwd_qkvpacked(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
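For callers this is a one-for-four swap; a hedged before/after sketch (keyword usage assumed):

    # Before: four per-tensor offset arguments, all shape [batch_size + 1].
    #   fused_attn_fwd_qkvpacked(..., seq_offsets_q=off, seq_offsets_k=off,
    #                            seq_offsets_v=off, seq_offsets_o=off, ...)
    # After: a single padded cumulative-length tensor.
    #   fused_attn_fwd_qkvpacked(..., cu_seqlens_padded=cu_pad, ...)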
@@ -124,14 +121,8 @@ def fused_attn_fwd_qkvpacked(
     attn_bias: torch.Tensor, default = None
                input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_padded: torch.Tensor, default = None
+               cumulative sequence offsets for QKV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
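The docstring distinguishes cu_seqlens (actual token counts) from cu_seqlens_padded (counts including padding). A minimal sketch of how the two would be built for a padded batch; all lengths are illustrative:

    import torch

    seqlens = torch.tensor([100, 200, 50], dtype=torch.int32)         # actual tokens
    seqlens_padded = torch.tensor([128, 256, 64], dtype=torch.int32)  # incl. padding

    def to_cu(lens: torch.Tensor) -> torch.Tensor:
        # Prepend 0 and take a running sum -> shape [batch_size + 1].
        zero = torch.zeros(1, dtype=torch.int32)
        return torch.cat([zero, torch.cumsum(lens, 0).to(torch.int32)])

    cu_seqlens = to_cu(seqlens)                # tensor([  0, 100, 300, 350])
    cu_seqlens_padded = to_cu(seqlens_padded)  # tensor([  0, 128, 384, 448])
    # With no padding the two coincide, which is why the default stays None.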
@@ -246,10 +237,7 @@ def fused_attn_fwd_qkvpacked(
         cu_seqlens,
         qkv,
         qkv_dtype,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_padded,
         d_scale_qkv,
         d_scale_s,
         q_scale_s,
@@ -275,10 +263,7 @@ def fused_attn_bwd_qkvpacked(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -322,14 +307,8 @@ def fused_attn_bwd_qkvpacked(
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                please see FusedAttention module for details on supported backends.
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_padded: torch.Tensor, default = None
+               cumulative sequence offsets for QKV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
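The backward pass consumes the same cu_seqlens_padded as the forward pass, so a caller would naturally stash it with the other saved tensors. A hedged autograd-style sketch (this is not the actual FusedAttention module; the fused calls are elided):

    import torch

    class _FusedAttnQKVPackedSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, qkv, cu_seqlens, cu_seqlens_padded):
            # out, aux = fused_attn_fwd_qkvpacked(..., cu_seqlens_padded=cu_seqlens_padded)
            ctx.save_for_backward(qkv, cu_seqlens, cu_seqlens_padded)
            return qkv.clone()  # placeholder output for the sketch

        @staticmethod
        def backward(ctx, grad_out):
            qkv, cu_seqlens, cu_seqlens_padded = ctx.saved_tensors
            # dqkv = fused_attn_bwd_qkvpacked(..., cu_seqlens_padded=cu_seqlens_padded)
            return grad_out, None, None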
@@ -419,10 +398,7 @@ def fused_attn_bwd_qkvpacked(
         qkv_dtype,
         dqkv_dtype,
         aux_ctx_tensors,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_padded,
         d_scale_qkv,
         d_scale_s,
         d_scale_o,
@@ -449,10 +425,8 @@ def fused_attn_fwd_kvpacked(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_q_padded: torch.Tensor = None,
+    cu_seqlens_kv_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
@@ -495,14 +469,10 @@ def fused_attn_fwd_kvpacked(
     attn_bias: torch.Tensor, default = None
                input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_q_padded: torch.Tensor, default = None
+               cumulative sequence offsets for Q; shape [batch_size + 1]
+    cu_seqlens_kv_padded: torch.Tensor, default = None
+               cumulative sequence offsets for KV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
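The kv-packed path serves cross-attention, where Q and KV generally have different lengths; that is why this variant keeps two tensors where the qkv-packed one keeps a single cu_seqlens_padded. A small sketch with illustrative lengths:

    import torch

    def to_cu(lens: torch.Tensor) -> torch.Tensor:
        zero = torch.zeros(1, dtype=torch.int32)
        return torch.cat([zero, torch.cumsum(lens, 0).to(torch.int32)])

    # e.g. short decoder queries attending to a longer encoder context
    cu_seqlens_q_padded = to_cu(torch.tensor([32, 32], dtype=torch.int32))     # [0, 32, 64]
    cu_seqlens_kv_padded = to_cu(torch.tensor([512, 384], dtype=torch.int32))  # [0, 512, 896]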
@@ -621,10 +591,8 @@ def fused_attn_fwd_kvpacked(
         q,
         kv,
         qkv_dtype,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_q_padded,
+        cu_seqlens_kv_padded,
         d_scale_qkv,
         d_scale_s,
         q_scale_s,
@@ -653,10 +621,8 @@ def fused_attn_bwd_kvpacked(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_q_padded: torch.Tensor = None,
+    cu_seqlens_kv_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -707,14 +673,10 @@ def fused_attn_bwd_kvpacked(
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                please see FusedAttention module for details on supported backends.
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_q_padded: torch.Tensor, default = None
+               cumulative sequence offsets for Q; shape [batch_size + 1]
+    cu_seqlens_kv_padded: torch.Tensor, default = None
+               cumulative sequence offsets for KV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -811,10 +773,8 @@ def fused_attn_bwd_kvpacked(
         qkv_dtype,
         dqkv_dtype,
         aux_ctx_tensors,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_q_padded,
+        cu_seqlens_kv_padded,
         d_scale_qkv,
         d_scale_s,
         d_scale_o,
@@ -842,10 +802,8 @@ def fused_attn_fwd(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_q_padded: torch.Tensor = None,
+    cu_seqlens_kv_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
@@ -892,14 +850,10 @@ def fused_attn_fwd(
     attn_bias: torch.Tensor, default = None
                input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_q_padded: torch.Tensor, default = None
+               cumulative sequence offsets for Q; shape [batch_size + 1]
+    cu_seqlens_kv_padded: torch.Tensor, default = None
+               cumulative sequence offsets for KV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of Q, K and V in FP8 computations
     d_scale_s: torch.Tensor, default = None
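For the thd format the padded offsets also pin down each tensor's leading dimension: the total padded token count is the last entry of the offsets tensor, and sequence i occupies a contiguous token slice. A sketch of that bookkeeping (every size illustrative):

    import torch

    num_heads, head_dim = 16, 64
    cu_seqlens_q_padded = torch.tensor([0, 128, 384, 448], dtype=torch.int32)

    total_tokens_q = int(cu_seqlens_q_padded[-1])         # 448
    q = torch.empty(total_tokens_q, num_heads, head_dim)  # thd layout: [t, h, d]

    # Tokens of sequence i live at q[cu_seqlens_q_padded[i] : cu_seqlens_q_padded[i + 1]].
    start, stop = int(cu_seqlens_q_padded[1]), int(cu_seqlens_q_padded[2])
    q_seq1 = q[start:stop]  # 256 padded tokens of the second sequence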
@@ -1021,10 +975,8 @@ def fused_attn_fwd(
         k,
         v,
         qkv_dtype,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_q_padded,
+        cu_seqlens_kv_padded,
         d_scale_qkv,
         d_scale_s,
         q_scale_s,
@@ -1054,10 +1006,8 @@ def fused_attn_bwd(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
-    seq_offsets_q: torch.Tensor = None,
-    seq_offsets_k: torch.Tensor = None,
-    seq_offsets_v: torch.Tensor = None,
-    seq_offsets_o: torch.Tensor = None,
+    cu_seqlens_q_padded: torch.Tensor = None,
+    cu_seqlens_kv_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -1111,14 +1061,10 @@ def fused_attn_bwd(
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                please see FusedAttention module for details on supported backends.
-    seq_offsets_q: torch.Tensor, default = None
-               cumulative sequence offsets for Q; shape [batch_size + 1]
-    seq_offsets_k: torch.Tensor, default = None
-               cumulative sequence offsets for K; shape [batch_size + 1]
-    seq_offsets_v: torch.Tensor, default = None
-               cumulative sequence offsets for V; shape [batch_size + 1]
-    seq_offsets_o: torch.Tensor, default = None
-               cumulative sequence offsets for O; shape [batch_size + 1]
+    cu_seqlens_q_padded: torch.Tensor, default = None
+               cumulative sequence offsets for Q; shape [batch_size + 1]
+    cu_seqlens_kv_padded: torch.Tensor, default = None
+               cumulative sequence offsets for KV; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of Q, K and V in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -1220,10 +1166,8 @@ def fused_attn_bwd(
         qkv_dtype,
         dqkv_dtype,
         aux_ctx_tensors,
-        seq_offsets_q,
-        seq_offsets_k,
-        seq_offsets_v,
-        seq_offsets_o,
+        cu_seqlens_q_padded,
+        cu_seqlens_kv_padded,
         d_scale_qkv,
         d_scale_s,
         d_scale_o,
...
@@ -26,22 +26,19 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
    size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
-   const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
-   const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
-   const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
-   const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
-   c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
-   const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
-   size_t rng_elts_per_thread);
+   const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
+   const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
+   const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
+   c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
+   const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
    size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
    const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-   const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
-   const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
-   const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
+   const std::vector<at::Tensor> Aux_CTX_Tensors,
+   const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
    const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
    const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
@@ -53,8 +50,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
    const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
-   const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
-   const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
+   const c10::optional<at::Tensor> cu_seqlens_q_padded,
+   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
@@ -67,27 +64,27 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
    const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-   const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
-   const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
-   const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
-   const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
-   const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
-   const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
-   const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
-   c10::optional<at::Tensor> amax_dQKV);
+   const std::vector<at::Tensor> Aux_CTX_Tensors,
+   const c10::optional<at::Tensor> cu_seqlens_q_padded,
+   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
+   const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
+   const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
+   const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
+   const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
+   c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);

std::vector<at::Tensor> fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
    const at::Tensor Q, const at::Tensor K, const at::Tensor V,
-   const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q,
-   const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
-   const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
-   const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
-   const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
-   c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
-   const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+   const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
+   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
+   const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
+   const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
+   c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
+   const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
+   size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
@@ -95,14 +92,14 @@ std::vector<at::Tensor> fused_attn_bwd(
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
    const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
    const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
-   const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
-   const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
-   const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
-   const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
-   const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
-   const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
-   const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
-   c10::optional<at::Tensor> amax_dQKV);
+   const std::vector<at::Tensor> Aux_CTX_Tensors,
+   const c10::optional<at::Tensor> cu_seqlens_q_padded,
+   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
+   const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
+   const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
+   const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
+   const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
+   c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);

at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
...
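On the PyTorch side each of these c10::optional parameters arrives as c10::nullopt whenever the Python caller leaves the corresponding argument at None, so only thd-format callers need to build real offsets. A hedged sketch of that dispatch (the helper name is made up for illustration):

    import torch

    def padded_offsets_or_none(qkv_format: str, seqlens_padded: torch.Tensor):
        # Only the thd format packs variable-length sequences, so only it needs
        # real cumulative padded offsets; other formats pass None, which the
        # extension receives as an empty c10::optional.
        if qkv_format != "thd":
            return None
        zero = torch.zeros(1, dtype=torch.int32, device=seqlens_padded.device)
        return torch.cat([zero, torch.cumsum(seqlens_padded, 0).to(torch.int32)])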