Commit e9606077 authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add THD support for cuDNN attention (#832)



* add THD support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add seq_offsets_o and use new offset calculation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* addition to previous commit; fix unit test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add None for offset_o gradient
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: test padding between sequences
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix tests for padding between sequences
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests for sbhd/bshd layouts; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend and add tests for max_seqlen_q=1 and d=256 for inference
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test sbhd/bshd layouts for sq1, d256 inference case
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace wording from accumulative to cumulative
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add offset tensors to custom fp8 mha tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add version control for cuDNN
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add sm>=90 constraint for thd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN support for sq=1, d=256
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and minor tweak for fp8 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* modify cudnn version and restrict MQA/GQA support for THD
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add notes for seq offset tensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass jax build
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass paddle build
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
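
For context, "THD" here is the ragged QKV format: the sequences of a batch are packed back-to-back along a single token axis ('t' = total tokens), and per-sequence boundaries travel in cumulative-sequence-length tensors instead of a fixed (batch, seq) shape. A minimal PyTorch sketch of that bookkeeping (names illustrative, not this PR's API):

import torch

# Hypothetical batch of three variable-length sequences, packed THD-style.
seqlens = torch.tensor([5, 2, 7], dtype=torch.int32)
num_heads, head_dim = 16, 64

# cu_seqlens[i] is the token offset where sequence i starts; the last entry
# is the total token count.
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)  # [0, 5, 7, 14]

# q/k/v in 'thd' layout carry one ragged token axis instead of (b, s).
q = torch.randn(int(cu_seqlens[-1]), num_heads, head_dim)

# Tokens of sequence i live at q[cu_seqlens[i]:cu_seqlens[i+1]].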
parent e8a17d1e
Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348
Subproject commit b740542818f36857acf7f9853f749bbad4118c65
......@@ -37,7 +37,7 @@ from transformer_engine.jax.softmax import SoftmaxType
is_fp8_supported, reason = is_fp8_available()
DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]
DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
......@@ -736,7 +736,7 @@ class TestDotProductAttn(TestLayer):
q_key, k_key, v_key = jax.random.split(key, 3)
b, s, *_ = shape
if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
b, s = s, b
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
......@@ -786,6 +786,7 @@ class MultiHeadAttnAttr:
ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
......@@ -795,42 +796,48 @@ class MultiHeadAttnAttr:
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
......@@ -839,7 +846,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
......@@ -848,7 +856,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
......@@ -857,7 +866,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'alternate',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
......@@ -865,7 +875,8 @@ class MultiHeadAttnAttr:
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all'
LORA_SCOPE: 'all',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
......@@ -873,7 +884,8 @@ class MultiHeadAttnAttr:
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all'
LORA_SCOPE: 'all',
TRANSPOSE_BS: True,
}]
......@@ -882,7 +894,9 @@ class TestMultiHeadAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2)
s, b, *_ = shape
b, s, *_ = shape
if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
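The reshuffle above lets one data shape serve both layouts: inputs are generated as (B, S, H) and swapped to (S, B, H) only when transpose_batch_sequence is set, while the mask keeps its (b, 1, s, s) shape either way. A small sketch of that swap (hypothetical helper, mirroring the test logic):

def maybe_transpose_bs(shape, transpose_bs):
    # (B, S, ...) -> (S, B, ...) when the layer expects sequence-major input.
    return (shape[1], shape[0]) + tuple(shape[2:]) if transpose_bs else tuple(shape)

assert maybe_transpose_bs((32, 128, 512), True) == (128, 32, 512)
assert maybe_transpose_bs((32, 128, 512), False) == (32, 128, 512)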
......@@ -906,7 +920,7 @@ class TestMultiHeadAttn(TestLayer):
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
fuse_qkv_params = True
transpose_batch_sequence = True
transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
......@@ -962,6 +976,7 @@ class TestMultiHeadAttn(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
......@@ -977,7 +992,7 @@ class TestMultiHeadAttn(TestLayer):
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
......@@ -1240,7 +1255,7 @@ class TestTransformer(TestLayer):
q_key, kv_key = jax.random.split(key, 2)
b, s, *_ = shape
if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
b, s = s, b
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
......
......@@ -194,13 +194,17 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True
def _is_unfused_attention_supported(config: ModelConfig) -> bool:
def _is_unfused_attention_supported(
config: ModelConfig,
qkv_format: str,
) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
return False
if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
return False
if qkv_format == 'thd':
return False
return True
......@@ -210,6 +214,8 @@ model_configs_base = {
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
}
......@@ -239,7 +245,9 @@ def get_swa(seq_q, seq_kv, w=None):
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
workspace_opt, qkv_layout, swa, pad_between_seqs):
"""Test DotProductAttention module"""
# Get configs
......@@ -258,7 +266,8 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
)
# Skip if only unfused backend is supported
unfused_attn_supported = _is_unfused_attention_supported(config)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
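The qkv_format one-liner above derives the format string from a layout name by keeping only the letters of the first underscore-separated group; a quick sketch of that behavior (illustrative, mirroring the test code):

def qkv_format_of(qkv_layout):
    return ''.join(c for c in qkv_layout.split('_')[0] if c.isalpha())

assert qkv_format_of('sb3hd') == 'sbhd'
assert qkv_format_of('bshd_bs2hd') == 'bshd'
assert qkv_format_of('thd_th2d') == 'thd'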
......@@ -269,14 +278,19 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type):
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
is_training = (config.head_dim <= 128)
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
dtype, config, "UnfusedDotProductAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
)
if swa:
config.attn_mask_type = attn_mask_type
......@@ -285,22 +299,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
)
if unfused_attn_supported and fused_attn_supported:
......@@ -335,7 +353,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mask = {
......@@ -361,7 +379,7 @@ model_configs_mask = {
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias = {
......@@ -399,7 +417,7 @@ model_configs_bias = {
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias_shapes = {
......@@ -426,7 +444,8 @@ model_configs_bias_shapes = {
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
......@@ -443,7 +462,8 @@ model_configs_swa = {
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
......@@ -460,14 +480,12 @@ model_configs_alibi_slopes = {
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
......@@ -481,6 +499,8 @@ model_configs_layout = {
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
}
......@@ -491,7 +511,41 @@ model_configs_layout = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd']
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
pad_between_seqs = True
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
def _run_dot_product_attention(
......@@ -502,6 +556,8 @@ def _run_dot_product_attention(
qkv_layout: str,
workspace_opt: bool,
swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
......@@ -537,6 +593,19 @@ def _run_dot_product_attention(
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
seqlens_q_after_pad = seqlens_q.clone()
seqlens_kv_after_pad = seqlens_kv.clone()
cu_seqlens_q_after_pad = cu_seqlens_q.clone()
cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
pad_len = [0] * config.batch_size
if pad_between_seqs:
max_pad_len = 3
pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda")
seqlens_q_after_pad = seqlens_q + pad_len
seqlens_kv_after_pad = seqlens_kv + pad_len
cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
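# A concrete illustration of the two offset vectors diverging (made-up numbers):
# with seqlens [5, 2, 7] and pad_len [1, 3, 0], cu_seqlens stays [0, 5, 7, 14]
# (valid tokens only) while cu_seqlens_after_pad becomes [0, 6, 11, 18]
# (storage offsets that include the pad slots between sequences).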
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
......@@ -582,13 +651,14 @@ def _run_dot_product_attention(
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
't' : cu_seqlens_q_after_pad[-1],
'tg' : cu_seqlens_kv_after_pad[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
inp_orig = []
for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout)
if i == 0:
......@@ -599,6 +669,21 @@ def _run_dot_product_attention(
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == 'thd' and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
......@@ -607,13 +692,35 @@ def _run_dot_product_attention(
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
inp_orig.append(tensors_orig[j].squeeze(split_dim))
else:
inp.append(tensors[j])
inp_orig.append(tensors_orig[j])
for i in range(3):
inp[i].requires_grad = True
inp_orig[i].requires_grad = True
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
if qkv_format == 'thd':
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == 'hd_hd_hd':
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
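# Note: these ragged offsets are element offsets (not bytes) into the flattened
# THD buffers: each token contributes num_heads * head_dim elements, scaled by
# 3 (or 2) when q/k/v (or k/v) interleave in one packed tensor. E.g., with
# cu_seqlens_q_after_pad = [0, 6, 11, 18], h = 16, d = 64, the 'hd_hd_hd'
# q offsets come out as [0, 6144, 11264, 18432].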
# Create output gradient
qkv_format_kv = '_'.join(qkv_format)
......@@ -621,6 +728,15 @@ def _run_dot_product_attention(
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == 'thd' and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if qkv_format_kv == 't_h_d':
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0]:pad_range[1]] = 0.0
out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0)
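# The zeroing above keeps the pad slots inert: FusedAttention consumes the
# padded buffer (out_grad) while the reference backends consume out_grad_orig,
# the same values with the pad slots removed, so gradients stay comparable.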
# Create bias
if config.attn_bias_type in ['no_bias', 'alibi']:
......@@ -659,21 +775,64 @@ def _run_dot_product_attention(
)
# Run a forward and backward pass
out = block(inp[0], inp[1], inp[2],
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
q = inp[0]
k = inp[1]
v = inp[2]
d_out = out_grad
out = block(q, k, v,
window_size=window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True)
out.backward(out_grad)
if is_training:
out.backward(d_out)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
if backend == "FusedAttention":
if qkv_format == 'thd' and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
for i in range(1, config.batch_size+1):
valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0)
k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
return out_orig, (None, None, None)
else:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
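# For THD with padding between sequences, the fused path returns padded
# buffers, so the output and q/k/v gradients are compacted back to their valid
# token ranges above before being compared against the unpadded reference
# backends.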
model_configs_te_layer = {
......@@ -714,7 +873,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported = _is_flash_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
......@@ -1568,7 +1727,7 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
None, None, None, None, None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -1648,6 +1807,7 @@ class _custom_mha_fp8(torch.autograd.Function):
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None, None, None, None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
......
......@@ -135,20 +135,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
|| (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
&& (max_seqlen_q % 64 == 0)
&& (max_seqlen_kv % 64 == 0)
&& ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0)
|| (cudnn_runtime_version >= 90000))
&& ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
|| (cudnn_runtime_version >= 8907))
&& ((head_dim <= 128) && (head_dim % 8 == 0))
&& ((head_dim <= 128 && head_dim % 8 == 0)
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90000
&& head_dim <= 256 && head_dim % 8 == 0))
&& ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| ((cudnn_runtime_version >= 8906)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
&& attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
&& sm_arch_ == 90)
&& sm_arch_ >= 90)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ == 90)))
&& sm_arch_ >= 90)))
|| ((cudnn_runtime_version >= 90000)
&& (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ >= 80)))
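Restated as a hedged sketch (Python pseudocode of the condition above, not a library API): before cuDNN 9.0 both max sequence lengths must be multiples of 64; head_dim must be a multiple of 8 and at most 128, except that sm90+ GPUs with cuDNN 9.0+ also allow head_dim up to 256, currently forward-only per the TODO.

def head_dim_ok(head_dim, sm_arch, cudnn_version):
    # d <= 128 is broadly supported; d <= 256 needs sm90+ and cuDNN 9.0+.
    if head_dim % 8 != 0:
        return False
    return head_dim <= 128 or (sm_arch >= 90 and cudnn_version >= 90000
                               and head_dim <= 256)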
......@@ -163,6 +167,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
&& bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90100
&& num_attn_heads == num_gqa_groups
&& qkv_format == NVTE_QKV_Format::NVTE_THD)
|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true;
}
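The new THD branch in the arbitrary-seqlen gate can likewise be summarized (a sketch under the same caveat): THD is accepted only on sm90+ with cuDNN 9.1.0+ and, for now, without MQA/GQA.

def thd_format_ok(sm_arch, cudnn_version, num_attn_heads, num_gqa_groups):
    # Mirrors the condition above: Hopper+, cuDNN >= 9.1.0 (90100), and
    # matching head counts (MQA/GQA is restricted for THD at this point).
    return (sm_arch >= 90 and cudnn_version >= 90100
            and num_attn_heads == num_gqa_groups)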
......@@ -211,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked(
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
......@@ -222,6 +233,10 @@ void nvte_fused_attn_fwd_qkvpacked(
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
......@@ -272,6 +287,7 @@ void nvte_fused_attn_fwd_qkvpacked(
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state,
wkspace, stream, handle);
#else
......@@ -306,6 +322,10 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dQKV,
NVTETensor dBias,
const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -316,6 +336,10 @@ void nvte_fused_attn_bwd_qkvpacked(
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
......@@ -377,7 +401,9 @@ void nvte_fused_attn_bwd_qkvpacked(
input_QKV, input_O, input_dO, input_Bias,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
......@@ -417,6 +443,10 @@ void nvte_fused_attn_fwd_kvpacked(
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
......@@ -428,6 +458,10 @@ void nvte_fused_attn_fwd_kvpacked(
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
......@@ -482,6 +516,7 @@ void nvte_fused_attn_fwd_kvpacked(
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state,
wkspace, stream, handle);
#else
......@@ -519,6 +554,10 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -529,6 +568,10 @@ void nvte_fused_attn_bwd_kvpacked(
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
......@@ -596,6 +639,7 @@ void nvte_fused_attn_bwd_kvpacked(
output_S,
output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
......@@ -636,6 +680,10 @@ void nvte_fused_attn_fwd(
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
......@@ -647,6 +695,10 @@ void nvte_fused_attn_fwd(
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
......@@ -693,6 +745,7 @@ void nvte_fused_attn_fwd(
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state,
wkspace, stream, handle);
#else
......@@ -732,6 +785,10 @@ void nvte_fused_attn_bwd(
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -742,6 +799,10 @@ void nvte_fused_attn_bwd(
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
......@@ -802,6 +863,7 @@ void nvte_fused_attn_bwd(
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
......
......@@ -57,9 +57,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void *devPtrSoftmaxStats, void *devPtrO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK,
void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO,
cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
......@@ -67,6 +70,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try {
FADescriptor_v1 descriptor{b, h,
......@@ -89,6 +96,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
......@@ -113,8 +124,30 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
......@@ -124,18 +157,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
if (is_ragged) {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
} else {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
......@@ -197,7 +249,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
if (is_ragged) {
O->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o);
} else {
O->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(o_stride);
}
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
......@@ -213,11 +274,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ?
std::make_tuple(offset_q, offset_k, offset_v, offset_o) :
std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
......@@ -227,18 +288,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
Stats_tuple, bias_tuple, padding_tuple, offset_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
bias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o,
dropout_seed, dropout_offset] = get_graph(
sdpa_f16_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
......@@ -277,6 +338,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
}
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
......@@ -298,8 +366,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK,
void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
......@@ -307,6 +378,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try {
FADescriptor_v1 descriptor{b, h,
......@@ -334,6 +409,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
......@@ -358,8 +437,29 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
......@@ -372,26 +472,55 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
if (is_ragged) {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
} else {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
}
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
......@@ -465,15 +594,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto [dQ, dK, dV] = mha_graph->sdpa_backward(
q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
if (is_ragged) {
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride)
.set_ragged_offset(offset_q);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride)
.set_ragged_offset(offset_k);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride)
.set_ragged_offset(offset_v);
} else {
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
......@@ -490,11 +634,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ?
std::make_tuple(offset_q, offset_k, offset_v, offset_o) :
std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
......@@ -504,14 +648,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
bias_tuple, padding_tuple, offset_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV,
bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o,
dropout_seed, dropout_offset] = get_graph(
sdpa_f16_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
......@@ -564,6 +709,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
}
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
......@@ -581,8 +733,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
......@@ -609,6 +762,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -665,6 +822,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
......@@ -690,9 +849,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -732,6 +892,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
......@@ -747,6 +911,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
......@@ -769,8 +935,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -800,6 +967,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -856,6 +1027,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
......@@ -885,9 +1058,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
......@@ -926,6 +1100,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
......@@ -941,6 +1119,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
......@@ -966,7 +1146,8 @@ void fused_attn_arbitrary_seqlen_fwd(
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -987,6 +1168,10 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
......@@ -1043,6 +1228,8 @@ void fused_attn_arbitrary_seqlen_fwd(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size,
stream, handle);
......@@ -1072,11 +1259,11 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
......@@ -1102,6 +1289,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
......@@ -1116,6 +1307,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
......
......@@ -24,8 +24,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
......@@ -35,8 +37,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
......@@ -47,7 +51,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
......@@ -59,7 +64,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
......@@ -72,7 +78,8 @@ void fused_attn_arbitrary_seqlen_fwd(
const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
......@@ -86,7 +93,8 @@ void fused_attn_arbitrary_seqlen_bwd(
Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
......
......@@ -170,13 +170,35 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
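 * For illustration only, a minimal host-side sketch (not part of this API) of
 * building the no-padding offset tensors for a packed-QKV (`3hd`/`h3d`) layout,
 * assuming `cu_seqlens` is an int32 tensor of shape [batch_size + 1]:
 \verbatim
 import torch

 def make_qkvpacked_offsets(cu_seqlens, num_attn_heads, head_dim):
     # Each token of the packed tensor holds 3 * num_attn_heads * head_dim
     # elements (Q, K and V interleaved); O holds num_attn_heads * head_dim.
     qkv = 3 * num_attn_heads * head_dim * cu_seqlens
     o = num_attn_heads * head_dim * cu_seqlens
     return qkv, qkv.clone(), qkv.clone(), o

 cu_seqlens = torch.tensor([0, 1, 3, 4, 6], dtype=torch.int32)
 q_off, k_off, v_off, o_off = make_qkvpacked_offsets(cu_seqlens, 16, 64)
 \endverbatim
 *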
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen Max sequence length used for computing;
 * it may be >= max(seqlen_i) for i=0,...batch_size-1.
......@@ -196,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked(
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
......@@ -214,6 +240,24 @@ void nvte_fused_attn_fwd_qkvpacked(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
......@@ -223,7 +267,11 @@ void nvte_fused_attn_fwd_qkvpacked(
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
 * \param[in] max_seqlen Max sequence length used for computing;
 * it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
......@@ -244,6 +292,10 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dQKV,
NVTETensor dBias,
const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -266,6 +318,24 @@ void nvte_fused_attn_bwd_qkvpacked(
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
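 * For illustration only, a minimal sketch (not part of this API) of the
 * no-padding offsets for a KV-packed (`hd_2hd`/`hd_h2d`) layout with MQA/GQA,
 * where `num_gqa_groups` may be smaller than `num_attn_heads`:
 \verbatim
 import torch

 def make_kvpacked_offsets(cu_seqlens_q, cu_seqlens_kv,
                           num_attn_heads, num_gqa_groups, head_dim):
     q = num_attn_heads * head_dim * cu_seqlens_q
     # K and V are interleaved in one tensor, hence the factor of 2.
     k = 2 * num_gqa_groups * head_dim * cu_seqlens_kv
     o = num_attn_heads * head_dim * cu_seqlens_q
     return q, k, k.clone(), o
 \endverbatim
 *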
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
......@@ -273,8 +343,12 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used for computing for Q;
 * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
......@@ -298,6 +372,10 @@ void nvte_fused_attn_fwd_kvpacked(
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
......@@ -315,6 +393,24 @@ void nvte_fused_attn_fwd_kvpacked(
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward.
......@@ -326,8 +422,12 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
 * \param[in] max_seqlen_q Max sequence length used for computing for Q;
 * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
......@@ -353,6 +453,10 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -378,6 +482,34 @@ void nvte_fused_attn_bwd_kvpacked(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
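 * For illustration only, the layout-group dispatch above written out as a
 * small Python helper (a sketch; the group strings mirror
 * `nvte_get_qkv_layout_group`, and all other names are assumed inputs):
 \verbatim
 def make_thd_offsets(qkv_group, cu_seqlens_q, cu_seqlens_kv,
                      num_attn_heads, num_gqa_groups, head_dim):
     if qkv_group == 'hd_hd_hd':              # separate Q, K and V tensors
         q = num_attn_heads * head_dim * cu_seqlens_q
         k = v = num_gqa_groups * head_dim * cu_seqlens_kv
     elif qkv_group in ('3hd', 'h3d'):        # Q, K and V packed together
         q = k = v = 3 * num_attn_heads * head_dim * cu_seqlens_q
     elif qkv_group in ('hd_2hd', 'hd_h2d'):  # K and V packed together
         q = num_attn_heads * head_dim * cu_seqlens_q
         k = v = 2 * num_gqa_groups * head_dim * cu_seqlens_kv
     else:
         raise ValueError(f'unsupported layout group: {qkv_group}')
     o = num_attn_heads * head_dim * cu_seqlens_q
     return q, k, v, o
 \endverbatim
 *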
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
......@@ -388,6 +520,10 @@ void nvte_fused_attn_bwd_kvpacked(
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used for computing for Q;
 * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
......@@ -412,6 +548,10 @@ void nvte_fused_attn_fwd(
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
......@@ -432,6 +572,34 @@ void nvte_fused_attn_fwd(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim
*
* Notes:
*
 * Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
 * locate the start of each sequence in tensors Q, K, V and O, respectively.
 * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
 * the offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
 * When the QKV format is `thd`, these tensors must follow the rules below.
 * When there is no padding between sequences, the offset tensors are:
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
 * When there is padding between sequences, users are responsible for adjusting the offsets as needed.
 * For example, a THD tensor packing 4 sequences as `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]` (shown in units of tokens;
 * scale by the per-token element counts above).
*
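 * For illustration only, a sketch (not part of this API) deriving the padded
 * token offsets of the example above from per-sequence lengths and pad counts;
 * scaling them by the per-token element counts gives the final offset tensors:
 \verbatim
 import itertools

 seqlens = [1, 2, 1, 2]   # [a], [b, b], [c], [d, d]
 pads    = [1, 0, 2, 0]   # padding tokens after each sequence

 cu_seqlens  = [0] + list(itertools.accumulate(seqlens))
 seq_offsets = [0] + list(itertools.accumulate(s + p for s, p in zip(seqlens, pads)))

 assert cu_seqlens  == [0, 1, 3, 4, 6]
 assert seq_offsets == [0, 2, 4, 7, 9]
 \endverbatim
 *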
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
......@@ -447,6 +615,10 @@ void nvte_fused_attn_fwd(
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
 * \param[in] max_seqlen_q Max sequence length used for computing for Q;
 * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
......@@ -474,6 +646,10 @@ void nvte_fused_attn_bwd(
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
......@@ -1239,27 +1239,37 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
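  // bshd/sbhd layouts do not use THD ragged offsets; pass dummy (null) tensors.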
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
......@@ -1294,6 +1304,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
......@@ -1304,8 +1316,10 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -1319,9 +1333,12 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -1336,11 +1353,15 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
&aux_input_tensors,
dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
......@@ -1416,6 +1437,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
  /* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0];
......@@ -1423,9 +1446,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
&aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -1437,9 +1463,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -1453,9 +1481,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
descriptor.is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -1496,15 +1527,20 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
&aux_input_tensors,
dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
......@@ -1574,6 +1610,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
  /* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0];
......@@ -1586,8 +1624,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -1605,9 +1645,11 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -1629,10 +1671,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
dv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......
......@@ -647,10 +647,13 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// create workspace
TensorWrapper workspace;
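  // bshd/sbhd layouts do not use THD seq offsets; pass dummy (null) tensors.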
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
......@@ -664,7 +667,9 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
......@@ -730,10 +735,13 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace
......@@ -743,7 +751,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers
......@@ -816,10 +826,13 @@ void te_fused_attn_fwd_kvpacked(
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -834,7 +847,9 @@ void te_fused_attn_fwd_kvpacked(
// execute the kernel
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -909,11 +924,14 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace
......@@ -924,7 +942,9 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers
......@@ -989,10 +1009,13 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -1008,7 +1031,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -1084,11 +1109,14 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
......@@ -1100,7 +1128,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
......
......@@ -1683,8 +1683,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
assert (qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (qkv_format != 'thd'
), """UnfusedDotProductAttention does not support variable sequence lengths!"""
if qkv_format == 'bshd':
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [x.transpose(0, 1)
......@@ -2067,7 +2065,7 @@ class FlashAttention(torch.nn.Module):
else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
elif qkv_format == 'bshd':
elif qkv_format in ['bshd', 'thd']:
query_layer, key_layer, value_layer = [x.contiguous()
for x in (query_layer, key_layer, value_layer)]
......@@ -2181,7 +2179,7 @@ class FlashAttention(torch.nn.Module):
**fa_optional_forward_kwargs,
)
if 'padding' in attn_mask_type:
if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type:
output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
if qkv_format == 'sbhd':
......@@ -2230,7 +2228,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
def forward(ctx, is_training, max_seqlen, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd,
fp8, fp8_meta):
......@@ -2257,6 +2257,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens,
qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2297,6 +2298,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
......@@ -2305,7 +2307,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors, *aux_ctx_tensors)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
......@@ -2330,7 +2334,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(qkv, out, cu_seqlens, qkv_fp8, out_fp8,
(qkv, out, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
qkv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
......@@ -2369,6 +2375,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -2404,17 +2411,18 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, dqkv, None, None, None,
return (None, None, None, None, None, None, None, dqkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dqkv, dbias)
return (None, None, None, dqkv, None, rest[0], None,
return (None, None, None, None, None, None, None, dqkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
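        # Why the extra leading `None`s: torch.autograd.Function requires
        # backward() to return exactly one gradient per forward() input, in
        # order. The four new seq_offsets_q/k/v/o inputs are integer index
        # tensors that carry no gradient, so each contributes a `None`.
        # A minimal sketch of the pattern (illustrative only, assumed names):
        #
        #     class _ScaleWithOffsets(torch.autograd.Function):
        #         @staticmethod
        #         def forward(ctx, x, seq_offsets):  # seq_offsets: int tensor
        #             ctx.save_for_backward(seq_offsets)
        #             return 2 * x
        #
        #         @staticmethod
        #         def backward(ctx, grad_out):
        #             return 2 * grad_out, None      # one slot per input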
......@@ -2424,6 +2432,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta):
......@@ -2454,6 +2463,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2497,6 +2507,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
......@@ -2506,6 +2517,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.max_seqlen_q = max_seqlen_q
......@@ -2532,7 +2544,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(q, kv, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, out_fp8,
(q, kv, out, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q_fp8, kv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
......@@ -2573,6 +2587,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
q_fp8, kv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -2620,17 +2635,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
q, kv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, dq, dkv, None, None, None,
return (None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dqkv, dbias)
return (None, None, None, None, None, dq, dkv, None, rest[0], None,
return (None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
......@@ -2639,6 +2655,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta):
......@@ -2690,6 +2707,7 @@ class FusedAttnFunc(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2761,6 +2779,7 @@ class FusedAttnFunc(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
......@@ -2778,6 +2797,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.max_seqlen_q = max_seqlen_q
......@@ -2804,7 +2824,9 @@ class FusedAttnFunc(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q_fp8, k_fp8, v_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
......@@ -2846,6 +2868,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
......@@ -2929,17 +2952,20 @@ class FusedAttnFunc(torch.autograd.Function):
q, k, v, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return (dq, dk, dv)
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, None,
None, None, None, dq, dk, dv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dq, dk, dv, dbias)
return (None, None, None, None, None, None,
None, None, None, dq, dk, dv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
......@@ -3032,6 +3058,10 @@ class FusedAttention(TransformerEngineBaseModule):
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal",
......@@ -3047,7 +3077,6 @@ class FusedAttention(TransformerEngineBaseModule):
is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor:
"""fused attention fprop"""
assert (fused_attention_backend
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!'
......@@ -3066,9 +3095,6 @@ class FusedAttention(TransformerEngineBaseModule):
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if qkv_format in ['sbhd', 'bshd']:
if qkv_format == 'sbhd':
......@@ -3104,6 +3130,34 @@ class FusedAttention(TransformerEngineBaseModule):
max_seqlen_kv,
key_layer.device,
)
if qkv_format == 'thd':
assert not context_parallel, "thd format not supported with context parallelism!"
assert (max_seqlen_q is not None
and max_seqlen_kv is not None
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if (seq_offsets_q is None
or seq_offsets_k is None
or seq_offsets_v is None
or seq_offsets_o is None):
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
num_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
head_dim = query_layer.shape[-1]
seq_offsets_o = num_heads * head_dim * cu_seqlens_q
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
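# A worked example (hypothetical sizes): for qkv_layout = 't3hd' with
# num_heads = 16 and head_dim = 64, each token occupies 3 * 16 * 64 = 3072
# elements in the packed QKV buffer, so cu_seqlens_q = [0, 2, 5] gives
# seq_offsets_q = seq_offsets_k = seq_offsets_v = [0, 6144, 15360], while the
# output holds one tensor per token:
# seq_offsets_o = 16 * 64 * cu_seqlens_q = [0, 2048, 5120].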
qkv_dtype = TE_DType[query_layer.dtype]
......@@ -3157,6 +3211,7 @@ class FusedAttention(TransformerEngineBaseModule):
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
query_layer, key_layer, value_layer,
qkv_dtype,
core_attention_bias,
......@@ -3430,6 +3485,10 @@ class DotProductAttention(torch.nn.Module):
qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None,
......@@ -3511,6 +3570,18 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32.
seq_offsets_q: Optional[torch.Tensor], default = `None`
Cumulative offset of each sequence in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. Used for `thd` layouts;
derived from `cu_seqlens_q` if not provided.
seq_offsets_k: Optional[torch.Tensor], default = `None`
Cumulative offset of each sequence in a batch for `key_layer`,
with shape [batch_size + 1] and dtype torch.int32. Used for `thd` layouts;
derived from `cu_seqlens_kv` if not provided.
seq_offsets_v: Optional[torch.Tensor], default = `None`
Cumulative offset of each sequence in a batch for `value_layer`,
with shape [batch_size + 1] and dtype torch.int32. Used for `thd` layouts;
derived from `cu_seqlens_kv` if not provided.
seq_offsets_o: Optional[torch.Tensor], default = `None`
Cumulative offset of each sequence in a batch for the forward output,
with shape [batch_size + 1] and dtype torch.int32. Used for `thd` layouts;
derived from `cu_seqlens_q` if not provided.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
......@@ -3581,6 +3652,9 @@ class DotProductAttention(torch.nn.Module):
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format == 'thd':
assert ('padding' in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if self.rng_states_tracker is not None and is_graph_capturing():
assert (
......@@ -3649,10 +3723,10 @@ class DotProductAttention(torch.nn.Module):
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
if max_seqlen_q is None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
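# Rounding max_seqlen up to the next power of 2 (e.g. a longest sequence of
# 300 tokens yields max_seqlen = 512) bounds the number of distinct shapes the
# backends see across thd batches, presumably to limit cuDNN graph recompilation.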
if qkv_format in ['sbhd', 'bshd']:
assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
......@@ -3690,6 +3764,10 @@ class DotProductAttention(torch.nn.Module):
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
# Filter: QKV layout.
if qkv_format == 'thd':
use_unfused_attention = False
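# The unfused path assumes padded sbhd/bshd tensors, so thd batches must go
# through the flash or fused attention backends.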
# Filter: ONNX export.
if is_in_onnx_export_mode():
use_flash_attention = False
......@@ -3891,6 +3969,10 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
......@@ -3910,6 +3992,10 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
......
......@@ -83,6 +83,10 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -118,6 +122,14 @@ def fused_attn_fwd_qkvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -225,6 +237,7 @@ def fused_attn_fwd_qkvpacked(
max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
rng_gen, rng_elts_per_thread,
)
......@@ -243,6 +256,10 @@ def fused_attn_bwd_qkvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -286,6 +303,14 @@ def fused_attn_bwd_qkvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -361,6 +386,7 @@ def fused_attn_bwd_qkvpacked(
max_seqlen, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......@@ -379,6 +405,10 @@ def fused_attn_fwd_kvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -421,6 +451,14 @@ def fused_attn_fwd_kvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -529,6 +567,7 @@ def fused_attn_fwd_kvpacked(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
......@@ -550,6 +589,10 @@ def fused_attn_bwd_kvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -600,6 +643,14 @@ def fused_attn_bwd_kvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -679,6 +730,7 @@ def fused_attn_bwd_kvpacked(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......@@ -698,6 +750,10 @@ def fused_attn_fwd(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -744,6 +800,14 @@ def fused_attn_fwd(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -854,6 +918,7 @@ def fused_attn_fwd(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
......@@ -876,6 +941,10 @@ def fused_attn_bwd(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -929,6 +998,14 @@ def fused_attn_bwd(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -1012,6 +1089,7 @@ def fused_attn_bwd(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......
......@@ -31,6 +31,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -54,6 +58,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......@@ -76,6 +84,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -101,6 +113,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......@@ -124,6 +140,10 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor K,
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -150,6 +170,10 @@ std::vector<at::Tensor> fused_attn_bwd(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......
......@@ -96,6 +96,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -123,6 +127,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -169,6 +174,32 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr);
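  // The seq_offsets tensors are only supplied for thd layouts. When any of them
  // is absent, the te_seq_offsets_* wrappers stay default-constructed (empty)
  // and are passed to the backend as null tensors.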
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract random number generator seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -193,6 +224,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
......@@ -241,6 +276,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen,
is_training, attn_scale, p_dropout,
......@@ -266,6 +305,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......@@ -380,6 +423,33 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// create workspace
TensorWrapper workspace;
......@@ -394,6 +464,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_dQKV.data(),
te_dBias.data(),
te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......@@ -417,6 +491,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_dQKV.data(),
te_dBias.data(),
te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......@@ -439,6 +517,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -462,6 +544,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -516,6 +599,32 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -542,6 +651,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
......@@ -592,6 +705,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
......@@ -620,6 +737,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......@@ -725,6 +846,33 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
......@@ -771,6 +919,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......@@ -797,6 +949,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......@@ -820,6 +976,10 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor K,
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
......@@ -844,6 +1004,7 @@ std::vector<at::Tensor> fused_attn_fwd(
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -902,6 +1063,32 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -930,6 +1117,10 @@ std::vector<at::Tensor> fused_attn_fwd(
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
......@@ -981,6 +1172,10 @@ std::vector<at::Tensor> fused_attn_fwd(
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
......@@ -1010,6 +1205,10 @@ std::vector<at::Tensor> fused_attn_bwd(
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
......@@ -1183,6 +1382,33 @@ std::vector<at::Tensor> fused_attn_bwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
......@@ -1231,6 +1457,10 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......@@ -1259,6 +1489,10 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
......