Unverified Commit e9606077 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[C/PyTorch] Add THD support for cuDNN attention (#832)



* add THD support
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add seq_offsets_o and use new offset calculation
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* addition to previous commit; fix unit test
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add None for offset_o gradient
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: test padding between sequences
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix tests for padding between sequences
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests for sbhd/bshd layouts; clean up
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend and add tests for max_seqlen_q=1 and d=256 for inference
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test sbhd/bshd layouts for sq1, d256 inference case
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace wording from accumulative to cumulative
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add offset tensors to custom fp8 mha tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add version control for cuDNN
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add sm>=90 constraint for thd support
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN support for sq=1, d=256
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and minor tweak for fp8 tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* modify cudnn version and restrict MQA/GQA support for THD
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add notes for seq offset tensors
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass jax build
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass paddle build
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: default avatarcyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent e8a17d1e
Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348 Subproject commit b740542818f36857acf7f9853f749bbad4118c65
...@@ -37,7 +37,7 @@ from transformer_engine.jax.softmax import SoftmaxType ...@@ -37,7 +37,7 @@ from transformer_engine.jax.softmax import SoftmaxType
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
DATA_SHAPE = [(128, 32, 512), (512, 32, 512)] DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16] DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True] ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
...@@ -736,7 +736,7 @@ class TestDotProductAttn(TestLayer): ...@@ -736,7 +736,7 @@ class TestDotProductAttn(TestLayer):
q_key, k_key, v_key = jax.random.split(key, 3) q_key, k_key, v_key = jax.random.split(key, 3)
b, s, *_ = shape b, s, *_ = shape
if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]: if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
b, s = s, b shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
...@@ -786,6 +786,7 @@ class MultiHeadAttnAttr: ...@@ -786,6 +786,7 @@ class MultiHeadAttnAttr:
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads' NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups' NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb' ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope' LORA_SCOPE = 'low_rank_adaptation_scope'
...@@ -795,42 +796,48 @@ class MultiHeadAttnAttr: ...@@ -795,42 +796,48 @@ class MultiHeadAttnAttr:
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding' ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
ZERO_CEN: True, ZERO_CEN: True,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding' ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding' ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
ZERO_CEN: True, ZERO_CEN: True,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
...@@ -839,7 +846,8 @@ class MultiHeadAttnAttr: ...@@ -839,7 +846,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
...@@ -848,7 +856,8 @@ class MultiHeadAttnAttr: ...@@ -848,7 +856,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
...@@ -857,7 +866,8 @@ class MultiHeadAttnAttr: ...@@ -857,7 +866,8 @@ class MultiHeadAttnAttr:
ROPE_GROUP_METHOD: 'alternate', ROPE_GROUP_METHOD: 'alternate',
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
...@@ -865,7 +875,8 @@ class MultiHeadAttnAttr: ...@@ -865,7 +875,8 @@ class MultiHeadAttnAttr:
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding', ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all' LORA_SCOPE: 'all',
TRANSPOSE_BS: False,
}, { }, {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
...@@ -873,7 +884,8 @@ class MultiHeadAttnAttr: ...@@ -873,7 +884,8 @@ class MultiHeadAttnAttr:
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all' LORA_SCOPE: 'all',
TRANSPOSE_BS: True,
}] }]
...@@ -882,7 +894,9 @@ class TestMultiHeadAttn(TestLayer): ...@@ -882,7 +894,9 @@ class TestMultiHeadAttn(TestLayer):
def input_getter(self, shape, dtype): def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234) key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2) q_key, kv_key = jax.random.split(key, 2)
s, b, *_ = shape b, s, *_ = shape
if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
...@@ -906,7 +920,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -906,7 +920,7 @@ class TestMultiHeadAttn(TestLayer):
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
fuse_qkv_params = True fuse_qkv_params = True
transpose_batch_sequence = True transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False scale_attn_logits = False
scaled_query_init = True scaled_query_init = True
float32_logits = False float32_logits = False
...@@ -962,6 +976,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -962,6 +976,7 @@ class TestMultiHeadAttn(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
...@@ -977,7 +992,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -977,7 +992,7 @@ class TestMultiHeadAttn(TestLayer):
fp8_format, fp8_format,
rtol=1e-05, rtol=1e-05,
atol=1e-08): atol=1e-08):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...@@ -1240,7 +1255,7 @@ class TestTransformer(TestLayer): ...@@ -1240,7 +1255,7 @@ class TestTransformer(TestLayer):
q_key, kv_key = jax.random.split(key, 2) q_key, kv_key = jax.random.split(key, 2)
b, s, *_ = shape b, s, *_ = shape
if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]: if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
b, s = s, b shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
......
...@@ -194,13 +194,17 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool: ...@@ -194,13 +194,17 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False return False
return True return True
def _is_unfused_attention_supported(
def _is_unfused_attention_supported(config: ModelConfig) -> bool: config: ModelConfig,
qkv_format: str,
) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration""" """Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type): if ("padding" in config.attn_mask_type):
return False return False
if ("causal" in config.attn_mask_type and config.attn_type == 'cross'): if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
return False return False
if qkv_format == 'thd':
return False
return True return True
...@@ -210,6 +214,8 @@ model_configs_base = { ...@@ -210,6 +214,8 @@ model_configs_base = {
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
} }
...@@ -239,7 +245,9 @@ def get_swa(seq_q, seq_kv, w=None): ...@@ -239,7 +245,9 @@ def get_swa(seq_q, seq_kv, w=None):
@pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False]) @pytest.mark.parametrize("swa", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa): @pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
workspace_opt, qkv_layout, swa, pad_between_seqs):
"""Test DotProductAttention module""" """Test DotProductAttention module"""
# Get configs # Get configs
...@@ -258,7 +266,8 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace ...@@ -258,7 +266,8 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
) )
# Skip if only unfused backend is supported # Skip if only unfused backend is supported
unfused_attn_supported = _is_unfused_attention_supported(config) qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512: if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
...@@ -269,14 +278,19 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace ...@@ -269,14 +278,19 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
flash_attn_supported = _is_flash_attention_supported(config) flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.") pytest.skip("Less than two backends to compare.")
if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type):
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
is_training = (config.head_dim <= 128)
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
if swa: if swa:
attn_mask_type = config.attn_mask_type attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary" config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa, dtype, config, "UnfusedDotProductAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
) )
if swa: if swa:
config.attn_mask_type = attn_mask_type config.attn_mask_type = attn_mask_type
...@@ -285,22 +299,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace ...@@ -285,22 +299,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
if fused_attn_supported: if fused_attn_supported:
if len(fused_attn_backend) == 1: if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa, dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
) )
if len(fused_attn_backend) == 2: if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa, dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
) )
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa, dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
) )
# FlashAttention backend # FlashAttention backend
if flash_attn_supported: if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa, dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
) )
if unfused_attn_supported and fused_attn_supported: if unfused_attn_supported and fused_attn_supported:
...@@ -335,7 +353,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace ...@@ -335,7 +353,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"]) @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model): def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing""" """Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False) test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mask = { model_configs_mask = {
...@@ -361,7 +379,7 @@ model_configs_mask = { ...@@ -361,7 +379,7 @@ model_configs_mask = {
@pytest.mark.parametrize("model", model_configs_mask.keys()) @pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model): def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types""" """Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False) test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias = { model_configs_bias = {
...@@ -399,7 +417,7 @@ model_configs_bias = { ...@@ -399,7 +417,7 @@ model_configs_bias = {
@pytest.mark.parametrize("model", model_configs_bias.keys()) @pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model): def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types""" """Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False) test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias_shapes = { model_configs_bias_shapes = {
...@@ -426,7 +444,8 @@ model_configs_bias_shapes = { ...@@ -426,7 +444,8 @@ model_configs_bias_shapes = {
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys()) @pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model): def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes""" """Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False) test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_swa = { model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
...@@ -443,7 +462,8 @@ model_configs_swa = { ...@@ -443,7 +462,8 @@ model_configs_swa = {
@pytest.mark.parametrize("model", model_configs_swa.keys()) @pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model): def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention""" """Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True) test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
model_configs_alibi_slopes = { model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
...@@ -460,14 +480,12 @@ model_configs_alibi_slopes = { ...@@ -460,14 +480,12 @@ model_configs_alibi_slopes = {
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys()) @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model): def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes""" """Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False) test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
qkv_layouts = [ qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd', 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd', 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
] ]
...@@ -481,6 +499,8 @@ model_configs_layout = { ...@@ -481,6 +499,8 @@ model_configs_layout = {
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
} }
...@@ -491,7 +511,41 @@ model_configs_layout = { ...@@ -491,7 +511,41 @@ model_configs_layout = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts) @pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts""" """Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False) test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd']
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
pad_between_seqs = True
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
def _run_dot_product_attention( def _run_dot_product_attention(
...@@ -502,6 +556,8 @@ def _run_dot_product_attention( ...@@ -502,6 +556,8 @@ def _run_dot_product_attention(
qkv_layout: str, qkv_layout: str,
workspace_opt: bool, workspace_opt: bool,
swa: bool, swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass""" """Run DotProductAttention module with one forward pass and one backward pass"""
...@@ -537,6 +593,19 @@ def _run_dot_product_attention( ...@@ -537,6 +593,19 @@ def _run_dot_product_attention(
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
seqlens_q_after_pad = seqlens_q.clone()
seqlens_kv_after_pad = seqlens_kv.clone()
cu_seqlens_q_after_pad = cu_seqlens_q.clone()
cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
pad_len = [0] * config.batch_size
if pad_between_seqs:
max_pad_len = 3
pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda") #3
seqlens_q_after_pad = seqlens_q + pad_len
seqlens_kv_after_pad = seqlens_kv + pad_len
cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
# Create attention mask if padding # Create attention mask if padding
attention_mask = None attention_mask = None
if "padding" in config.attn_mask_type: if "padding" in config.attn_mask_type:
...@@ -582,13 +651,14 @@ def _run_dot_product_attention( ...@@ -582,13 +651,14 @@ def _run_dot_product_attention(
'h' : config.num_heads, 'h' : config.num_heads,
'hg' : config.num_gqa_groups, 'hg' : config.num_gqa_groups,
'd' : config.head_dim, 'd' : config.head_dim,
't' : cu_seqlens_q[-1], 't' : cu_seqlens_q_after_pad[-1],
'tg' : cu_seqlens_kv[-1], 'tg' : cu_seqlens_kv_after_pad[-1],
'3' : 3, '3' : 3,
'2' : 2, '2' : 2,
'1' : 1, '1' : 1,
} }
inp = [] inp = []
inp_orig = []
for i,layout in enumerate(qkv_layout.split('_')): for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout) layout = '_'.join(layout)
if i == 0: if i == 0:
...@@ -599,6 +669,21 @@ def _run_dot_product_attention( ...@@ -599,6 +669,21 @@ def _run_dot_product_attention(
layout = layout.replace('t', 'tg') layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')] tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == 'thd' and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
tensor_count = 1 tensor_count = 1
split_dim = 0 split_dim = 0
for dim, l in enumerate(layout.split('_')): for dim, l in enumerate(layout.split('_')):
...@@ -607,13 +692,35 @@ def _run_dot_product_attention( ...@@ -607,13 +692,35 @@ def _run_dot_product_attention(
split_dim = dim split_dim = dim
break break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor] tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
for j in range(tensor_count): for j in range(tensor_count):
if split_dim != 0: if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim)) inp.append(tensors[j].squeeze(split_dim))
inp_orig.append(tensors_orig[j].squeeze(split_dim))
else: else:
inp.append(tensors[j]) inp.append(tensors[j])
inp_orig.append(tensors_orig[j])
for i in range(3): for i in range(3):
inp[i].requires_grad = True inp[i].requires_grad = True
inp_orig[i].requires_grad = True
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
if qkv_format == 'thd':
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == 'hd_hd_hd':
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient # Create output gradient
qkv_format_kv = '_'.join(qkv_format) qkv_format_kv = '_'.join(qkv_format)
...@@ -621,6 +728,15 @@ def _run_dot_product_attention( ...@@ -621,6 +728,15 @@ def _run_dot_product_attention(
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == 'thd' and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if qkv_format_kv == 't_h_d':
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0]:pad_range[1]] = 0.0
out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0)
# Create bias # Create bias
if config.attn_bias_type in ['no_bias', 'alibi']: if config.attn_bias_type in ['no_bias', 'alibi']:
...@@ -659,21 +775,64 @@ def _run_dot_product_attention( ...@@ -659,21 +775,64 @@ def _run_dot_product_attention(
) )
# Run a forward and backward pass # Run a forward and backward pass
out = block(inp[0], inp[1], inp[2], if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
q = inp[0]
k = inp[1]
v = inp[2]
d_out = out_grad
out = block(q, k, v,
window_size=window_size, window_size=window_size,
attention_mask=attention_mask, attention_mask=attention_mask,
qkv_format=qkv_format, qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
attn_mask_type=config.attn_mask_type, attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn, checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias, core_attention_bias=bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
fast_zero_fill=True) fast_zero_fill=True)
out.backward(out_grad) if is_training:
out.backward(d_out)
return out, (inp[0].grad, inp[1].grad, inp[2].grad) if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
if backend == "FusedAttention":
if qkv_format == 'thd' and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
for i in range(1, config.batch_size+1):
valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0)
k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
return out_orig, (None, None, None)
else:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
model_configs_te_layer = { model_configs_te_layer = {
...@@ -714,7 +873,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f ...@@ -714,7 +873,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd", qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
) )
flash_attn_supported = _is_flash_attention_supported(config) flash_attn_supported = _is_flash_attention_supported(config)
unfused_attn_supported = _is_unfused_attention_supported(config) unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.") pytest.skip("Less than two backends to compare.")
...@@ -1568,7 +1727,7 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -1568,7 +1727,7 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:], qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
fp8_dtype_forward, fp8_dtype_forward,
FusedAttnBackend["FP8"], FusedAttnBackend["FP8"],
None, None, None, None, None, None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -1648,6 +1807,7 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -1648,6 +1807,7 @@ class _custom_mha_fp8(torch.autograd.Function):
fp8_dtype_backward, fp8_dtype_backward,
ctx.aux_ctx_tensors, ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"], FusedAttnBackend["FP8"],
None, None, None, None,
fwd_scale_inverses[META_QKV], # d_scale_qkv, fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s, fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o, fwd_scale_inverses[META_O], # d_scale_o,
......
...@@ -135,20 +135,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -135,20 +135,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
} }
if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
|| (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
&& (max_seqlen_q % 64 == 0) && ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0)
&& (max_seqlen_kv % 64 == 0) || (cudnn_runtime_version >= 90000))
&& ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
|| (cudnn_runtime_version >= 8907)) || (cudnn_runtime_version >= 8907))
&& ((head_dim <= 128) && (head_dim % 8 == 0)) && ((head_dim <= 128 && head_dim % 8 == 0)
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90000
&& head_dim <= 256 && head_dim % 8 == 0))
&& ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| ((cudnn_runtime_version >= 8906) || ((cudnn_runtime_version >= 8906)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
&& attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
&& sm_arch_ == 90) && sm_arch_ >= 90)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ == 90))) && sm_arch_ >= 90)))
|| ((cudnn_runtime_version >= 90000) || ((cudnn_runtime_version >= 90000)
&& (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ >= 80))) && sm_arch_ >= 80)))
...@@ -163,6 +167,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -163,6 +167,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
|| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
&& bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
|| (sm_arch_ >= 90 && cudnn_runtime_version >= 90100
&& num_attn_heads == num_gqa_groups
&& qkv_format == NVTE_QKV_Format::NVTE_THD)
|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true; flag_arb = true;
} }
...@@ -211,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -211,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked(
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -222,6 +233,10 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -222,6 +233,10 @@ void nvte_fused_attn_fwd_qkvpacked(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens); const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV); const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias); const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
...@@ -272,6 +287,7 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -272,6 +287,7 @@ void nvte_fused_attn_fwd_qkvpacked(
input_QKV, input_Bias, output_O, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
input_cu_seqlens, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
...@@ -306,6 +322,10 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -306,6 +322,10 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dQKV, NVTETensor dQKV,
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen, size_t max_seqlen,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -316,6 +336,10 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -316,6 +336,10 @@ void nvte_fused_attn_bwd_qkvpacked(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens); const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV); const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O); const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO); const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
...@@ -377,7 +401,9 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -377,7 +401,9 @@ void nvte_fused_attn_bwd_qkvpacked(
input_QKV, input_O, input_dO, input_Bias, input_QKV, input_O, input_dO, input_Bias,
output_S, output_S,
output_dQKV, output_dBias, output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
const char *err_msg = const char *err_msg =
...@@ -417,6 +443,10 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -417,6 +443,10 @@ void nvte_fused_attn_fwd_kvpacked(
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -428,6 +458,10 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -428,6 +458,10 @@ void nvte_fused_attn_fwd_kvpacked(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV); const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
...@@ -482,6 +516,7 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -482,6 +516,7 @@ void nvte_fused_attn_fwd_kvpacked(
input_Q, input_KV, input_Bias, output_O, input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
...@@ -519,6 +554,10 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -519,6 +554,10 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -529,6 +568,10 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -529,6 +568,10 @@ void nvte_fused_attn_bwd_kvpacked(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV); const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O); const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
...@@ -596,6 +639,7 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -596,6 +639,7 @@ void nvte_fused_attn_bwd_kvpacked(
output_S, output_S,
output_dQ, output_dKV, output_dBias, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle); input_rng_state, wkspace, stream, handle);
#else #else
const char *err_msg = const char *err_msg =
...@@ -636,6 +680,10 @@ void nvte_fused_attn_fwd( ...@@ -636,6 +680,10 @@ void nvte_fused_attn_fwd(
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -647,6 +695,10 @@ void nvte_fused_attn_fwd( ...@@ -647,6 +695,10 @@ void nvte_fused_attn_fwd(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K); const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
...@@ -693,6 +745,7 @@ void nvte_fused_attn_fwd( ...@@ -693,6 +745,7 @@ void nvte_fused_attn_fwd(
input_Q, input_K, input_V, input_Bias, output_O, input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
...@@ -732,6 +785,10 @@ void nvte_fused_attn_bwd( ...@@ -732,6 +785,10 @@ void nvte_fused_attn_bwd(
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -742,6 +799,10 @@ void nvte_fused_attn_bwd( ...@@ -742,6 +799,10 @@ void nvte_fused_attn_bwd(
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K); const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V); const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
...@@ -802,6 +863,7 @@ void nvte_fused_attn_bwd( ...@@ -802,6 +863,7 @@ void nvte_fused_attn_bwd(
output_S, output_S,
output_dQ, output_dK, output_dV, output_dBias, output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle); input_rng_state, wkspace, stream, handle);
#else #else
const char *err_msg = const char *err_msg =
......
...@@ -57,9 +57,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -57,9 +57,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrO,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK,
void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO,
cudnn_frontend::DataType_t tensorType, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
...@@ -67,6 +70,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -67,6 +70,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_dropout = (is_training && dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try { try {
FADescriptor_v1 descriptor{b, h, FADescriptor_v1 descriptor{b, h,
...@@ -89,6 +96,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -89,6 +96,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // bias std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
...@@ -113,8 +124,30 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -113,8 +124,30 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale; std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv; std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset; std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
std::vector<int64_t> q_stride(4); std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4); std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4); std::vector<int64_t> v_stride(4);
...@@ -124,18 +157,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -124,18 +157,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout, NVTE_QKV_Matrix::NVTE_K_Matrix); layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix); layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q") if (is_ragged) {
.set_dim({b, h, s_q, d}) Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_stride(q_stride)); .set_name("Q")
K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_dim({b, h, s_q, d})
.set_name("K") .set_stride(q_stride)
.set_dim({b, hg, s_kv, d}) .set_ragged_offset(offset_q));
.set_stride(k_stride)); K = mha_graph->tensor(fe::graph::Tensor_attributes()
V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K")
.set_name("V") .set_dim({b, hg, s_kv, d})
.set_dim({b, hg, s_kv, d}) .set_stride(k_stride)
.set_stride(v_stride)); .set_ragged_offset(offset_k));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
} else {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale") .set_name("attn_scale")
...@@ -197,7 +249,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -197,7 +249,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> o_stride(4); std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix); layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); if (is_ragged) {
O->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o);
} else {
O->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(o_stride);
}
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1}) .set_dim({b, h, s_q, 1})
...@@ -213,11 +274,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -213,11 +274,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ? auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ?
std::make_tuple(offset_q, offset_k, offset_v, offset_o) :
std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto dropout_tuple = is_dropout ? auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
NVTE_CHECK_CUDNN_FE(mha_graph->validate()); NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
...@@ -227,18 +288,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -227,18 +288,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto return_tuple = std::tuple_cat( auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); Stats_tuple, bias_tuple, padding_tuple, offset_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple}); cache.insert({descriptor, return_tuple});
return return_tuple; return return_tuple;
}; };
auto [mha_graph, Q, K, V, attn_scale, O, Stats, auto [mha_graph, Q, K, V, attn_scale, O, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( bias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o,
dropout_seed, dropout_offset] = get_graph(
sdpa_f16_fprop_cache, descriptor); sdpa_f16_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size(); auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) { if (workspace == nullptr) {
...@@ -277,6 +338,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -277,6 +338,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV; variant_pack[seq_kv] = devActualSeqlenKV;
} }
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
}
if (is_dropout) { if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset; variant_pack[dropout_offset] = devPtrDropoutOffset;
...@@ -298,8 +366,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -298,8 +366,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK,
void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
...@@ -307,6 +378,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -307,6 +378,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f); bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try { try {
FADescriptor_v1 descriptor{b, h, FADescriptor_v1 descriptor{b, h,
...@@ -334,6 +409,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -334,6 +409,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_o
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
...@@ -358,8 +437,29 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -358,8 +437,29 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale; std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv; std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset; std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b+1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
std::vector<int64_t> q_stride(4); std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4); std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4); std::vector<int64_t> v_stride(4);
...@@ -372,26 +472,55 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -372,26 +472,55 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout, NVTE_QKV_Matrix::NVTE_V_Matrix); layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix); layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q") if (is_ragged) {
.set_dim({b, h, s_q, d}) q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_stride(q_stride)); .set_name("Q")
k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_dim({b, h, s_q, d})
.set_name("K") .set_stride(q_stride)
.set_dim({b, hg, s_kv, d}) .set_ragged_offset(offset_q));
.set_stride(k_stride)); k = mha_graph->tensor(fe::graph::Tensor_attributes()
v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K")
.set_name("V") .set_dim({b, hg, s_kv, d})
.set_dim({b, hg, s_kv, d}) .set_stride(k_stride)
.set_stride(v_stride)); .set_ragged_offset(offset_k));
o = mha_graph->tensor(fe::graph::Tensor_attributes() v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O") .set_name("V")
.set_dim({b, h, s_q, d}) .set_dim({b, hg, s_kv, d})
.set_stride(o_stride)); .set_stride(v_stride)
dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_ragged_offset(offset_v));
.set_name("dO") o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_dim({b, h, s_q, d}) .set_name("O")
.set_stride(o_stride)); .set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
} else {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
}
stats = mha_graph->tensor(fe::graph::Tensor_attributes() stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats") .set_name("stats")
.set_dim({b, h, s_q, 1}) .set_dim({b, h, s_q, 1})
...@@ -465,15 +594,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -465,15 +594,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto [dQ, dK, dV] = mha_graph->sdpa_backward( auto [dQ, dK, dV] = mha_graph->sdpa_backward(
q, k, v, o, dO, stats, sdpa_backward_options); q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true) if (is_ragged) {
.set_dim({b, h, s_q, d}) dQ->set_output(true)
.set_stride(q_stride); .set_dim({b, h, s_q, d})
dK->set_output(true) .set_stride(q_stride)
.set_dim({b, hg, s_kv, d}) .set_ragged_offset(offset_q);
.set_stride(k_stride); dK->set_output(true)
dV->set_output(true) .set_dim({b, hg, s_kv, d})
.set_dim({b, hg, s_kv, d}) .set_stride(k_stride)
.set_stride(v_stride); .set_ragged_offset(offset_k);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride)
.set_ragged_offset(offset_v);
} else {
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k std::shared_ptr<fe::graph::Tensor_attributes>, // k
...@@ -490,11 +634,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -490,11 +634,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple = is_padding ? auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_tuple = is_ragged ?
std::make_tuple(offset_q, offset_k, offset_v, offset_o) :
std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto dropout_tuple = is_dropout ? auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
auto return_empty_tuple = std::tuple_cat(
std::make_tuple(nullptr), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
NVTE_CHECK_CUDNN_FE(mha_graph->validate()); NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
...@@ -504,14 +648,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -504,14 +648,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto return_tuple = std::tuple_cat( auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, std::make_tuple(mha_graph), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple); bias_tuple, padding_tuple, offset_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple}); cache.insert({descriptor, return_tuple});
return return_tuple; return return_tuple;
}; };
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV,
bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o,
dropout_seed, dropout_offset] = get_graph(
sdpa_f16_bprop_cache, descriptor); sdpa_f16_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size(); auto plan_workspace_size = mha_graph->get_workspace_size();
...@@ -564,6 +709,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -564,6 +709,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV; variant_pack[seq_kv] = devActualSeqlenKV;
} }
if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ;
variant_pack[offset_k] = devPtrSeqOffsetsK;
variant_pack[offset_v] = devPtrSeqOffsetsV;
variant_pack[offset_o] = devPtrSeqOffsetsO;
}
if (is_dropout) { if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset; variant_pack[dropout_offset] = devPtrDropoutOffset;
...@@ -581,8 +733,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -581,8 +733,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype; const auto QKV_type = input_QKV->data.dtype;
...@@ -609,6 +762,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -609,6 +762,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void *devPtrO = output_O->data.dptr; void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr; void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -665,6 +822,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -665,6 +822,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, workspace->data.dptr, &workspace_size,
stream, handle); stream, handle);
...@@ -690,9 +849,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -690,9 +849,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -732,6 +892,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -732,6 +892,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrSoftmaxStats = output_S->data.dptr; devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
...@@ -747,6 +911,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -747,6 +911,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle); &workspace_size, stream, handle);
...@@ -769,8 +935,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -769,8 +935,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -800,6 +967,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -800,6 +967,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -856,6 +1027,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -856,6 +1027,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, workspace->data.dptr, &workspace_size,
stream, handle); stream, handle);
...@@ -885,9 +1058,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -885,9 +1058,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Bias, Tensor *output_S, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *rng_state, Tensor *workspace, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
cudaStream_t stream, cudnnHandle_t handle) { const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
...@@ -926,6 +1100,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -926,6 +1100,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
...@@ -941,6 +1119,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -941,6 +1119,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle); &workspace_size, stream, handle);
...@@ -966,7 +1146,8 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -966,7 +1146,8 @@ void fused_attn_arbitrary_seqlen_fwd(
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -987,6 +1168,10 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -987,6 +1168,10 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -1043,6 +1228,8 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1043,6 +1228,8 @@ void fused_attn_arbitrary_seqlen_fwd(
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, workspace->data.dptr, &workspace_size,
stream, handle); stream, handle);
...@@ -1072,11 +1259,11 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t ...@@ -1072,11 +1259,11 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
Tensor *output_S, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q,
const Tensor *rng_state, Tensor *workspace, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
cudaStream_t stream, cudnnHandle_t handle) { const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr; void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr; void *devPtrK = input_K->data.dptr;
...@@ -1102,6 +1289,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t ...@@ -1102,6 +1289,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
...@@ -1116,6 +1307,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t ...@@ -1116,6 +1307,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK,
devPtrSeqOffsetsV, devPtrSeqOffsetsO,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle); &workspace_size, stream, handle);
......
...@@ -24,8 +24,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -24,8 +24,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t batch, size_t num_attn_heads, size_t max_seqlen,
...@@ -35,8 +37,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -35,8 +37,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
...@@ -47,7 +51,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -47,7 +51,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked(
...@@ -59,7 +64,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -59,7 +64,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(
...@@ -72,7 +78,8 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -72,7 +78,8 @@ void fused_attn_arbitrary_seqlen_fwd(
const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd( void fused_attn_arbitrary_seqlen_bwd(
...@@ -86,7 +93,8 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -86,7 +93,8 @@ void fused_attn_arbitrary_seqlen_bwd(
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900 #endif // CUDNN_VERSION >= 8900
......
...@@ -170,13 +170,35 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -170,13 +170,35 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing, * \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1. * it may be >= max(seqlen_i) for i=0,...batch_size-1.
...@@ -196,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -196,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked(
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -214,6 +240,24 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -214,6 +240,24 @@ void nvte_fused_attn_fwd_qkvpacked(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor. * \param[in] dO The gradient of the O tensor.
...@@ -223,7 +267,11 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -223,7 +267,11 @@ void nvte_fused_attn_fwd_qkvpacked(
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor. * \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing, * \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1. * it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] attn_scale Scaling factor for Q * K.T.
...@@ -244,6 +292,10 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -244,6 +292,10 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dQKV, NVTETensor dQKV,
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen, size_t max_seqlen,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -266,6 +318,24 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -266,6 +318,24 @@ void nvte_fused_attn_bwd_qkvpacked(
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor, in HD layouts. * \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
...@@ -273,8 +343,12 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -273,8 +343,12 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -298,6 +372,10 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -298,6 +372,10 @@ void nvte_fused_attn_fwd_kvpacked(
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -315,6 +393,24 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -315,6 +393,24 @@ void nvte_fused_attn_fwd_kvpacked(
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor, in HD layouts. * \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts. * \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
...@@ -326,8 +422,12 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -326,8 +422,12 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor. * \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor. * \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV. * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
...@@ -353,6 +453,10 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -353,6 +453,10 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -378,6 +482,34 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -378,6 +482,34 @@ void nvte_fused_attn_bwd_kvpacked(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
* \param[in] K The K tensor. * \param[in] K The K tensor.
* \param[in] V The V tensor. * \param[in] V The V tensor.
...@@ -388,6 +520,10 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -388,6 +520,10 @@ void nvte_fused_attn_bwd_kvpacked(
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -412,6 +548,10 @@ void nvte_fused_attn_fwd( ...@@ -412,6 +548,10 @@ void nvte_fused_attn_fwd(
NVTETensorPack* Aux_CTX_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout, bool is_training, float attn_scale, float dropout,
...@@ -432,6 +572,34 @@ void nvte_fused_attn_fwd( ...@@ -432,6 +572,34 @@ void nvte_fused_attn_fwd(
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
\endverbatim \endverbatim
* *
* Notes:
*
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o`
* help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are,
\verbatim
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`.
*
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
* \param[in] K The K tensor. * \param[in] K The K tensor.
* \param[in] V The V tensor. * \param[in] V The V tensor.
...@@ -447,6 +615,10 @@ void nvte_fused_attn_fwd( ...@@ -447,6 +615,10 @@ void nvte_fused_attn_fwd(
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
...@@ -474,6 +646,10 @@ void nvte_fused_attn_bwd( ...@@ -474,6 +646,10 @@ void nvte_fused_attn_bwd(
NVTETensor dBias, NVTETensor dBias,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
...@@ -1239,27 +1239,37 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -1239,27 +1239,37 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen); assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
mask_type, query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), nvte_fused_attn_fwd_kvpacked(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
is_training, scaling_factor, dropout_probability, qkv_layout, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
bias_type, mask_type, query_workspace_tensor.data(), nullptr); dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported QKVLayout."); NVTE_ERROR("Unsupported QKVLayout.");
...@@ -1294,6 +1304,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -1294,6 +1304,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen); assert(q_max_seqlen == kv_max_seqlen);
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
...@@ -1304,8 +1316,10 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -1304,8 +1316,10 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -1319,9 +1333,12 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -1319,9 +1333,12 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -1336,11 +1353,15 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -1336,11 +1353,15 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
doutput_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), &aux_input_tensors,
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, dbias_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else { } else {
NVTE_ERROR("Unsupported QKVLayout."); NVTE_ERROR("Unsupported QKVLayout.");
} }
...@@ -1416,6 +1437,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -1416,6 +1437,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size}, auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype); descriptor.wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Call the underly NVTE API */ /* Call the underly NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0]; auto qkv = buffers[0];
...@@ -1423,9 +1446,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -1423,9 +1446,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen, &aux_output_tensors, q_cu_seqlens_tensor.data(),
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
bias_type, mask_type, workspace_tensor.data(), stream); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -1437,9 +1463,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -1437,9 +1463,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
workspace_tensor.data(), stream); rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -1453,9 +1481,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -1453,9 +1481,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, descriptor.is_training, scaling_factor,
bias_type, mask_type, workspace_tensor.data(), stream); dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -1496,15 +1527,20 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -1496,15 +1527,20 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors,
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dbias_tensor.data(),
dropout_probability, qkv_layout, bias_type, mask_type, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
...@@ -1574,6 +1610,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -1574,6 +1610,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto wkspace_dtype = descriptor.wkspace_dtype; auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Call the underly NVTE API */ /* Call the underly NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0]; auto qkv = buffers[0];
...@@ -1586,8 +1624,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -1586,8 +1624,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
workspace_tensor.data(), stream); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -1605,9 +1645,11 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -1605,9 +1645,11 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
workspace_tensor.data(), stream); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -1629,10 +1671,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -1629,10 +1671,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
workspace_tensor.data(), stream); dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
......
...@@ -647,10 +647,13 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -647,10 +647,13 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
...@@ -664,7 +667,9 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -664,7 +667,9 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
...@@ -730,10 +735,13 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -730,10 +735,13 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace // allocate memory for workspace
...@@ -743,7 +751,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -743,7 +751,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// execute kernel // execute kernel
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers // destroy tensor wrappers
...@@ -816,10 +826,13 @@ void te_fused_attn_fwd_kvpacked( ...@@ -816,10 +826,13 @@ void te_fused_attn_fwd_kvpacked(
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -834,7 +847,9 @@ void te_fused_attn_fwd_kvpacked( ...@@ -834,7 +847,9 @@ void te_fused_attn_fwd_kvpacked(
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -909,11 +924,14 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -909,11 +924,14 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace // allocate memory for workspace
...@@ -924,7 +942,9 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -924,7 +942,9 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
...@@ -989,10 +1009,13 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -989,10 +1009,13 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream()); attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -1008,7 +1031,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1008,7 +1031,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// execute the kernel // execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream()); attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -1084,11 +1109,14 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1084,11 +1109,14 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream()); Q.stream());
...@@ -1100,7 +1128,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1100,7 +1128,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream()); Q.stream());
......
...@@ -1683,8 +1683,6 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -1683,8 +1683,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
assert (qkv_layout in QKVLayouts assert (qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (qkv_format != 'thd'
), """UnfusedDotProductAttention does not support variable sequence lengths!"""
if qkv_format == 'bshd': if qkv_format == 'bshd':
# convert to sbhd and use sbhd implementation for now # convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [x.transpose(0, 1) query_layer, key_layer, value_layer = [x.transpose(0, 1)
...@@ -2067,7 +2065,7 @@ class FlashAttention(torch.nn.Module): ...@@ -2067,7 +2065,7 @@ class FlashAttention(torch.nn.Module):
else: else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)] for x in (query_layer, key_layer, value_layer)]
elif qkv_format == 'bshd': elif qkv_format in ['bshd', 'thd']:
query_layer, key_layer, value_layer = [x.contiguous() query_layer, key_layer, value_layer = [x.contiguous()
for x in (query_layer, key_layer, value_layer)] for x in (query_layer, key_layer, value_layer)]
...@@ -2181,7 +2179,7 @@ class FlashAttention(torch.nn.Module): ...@@ -2181,7 +2179,7 @@ class FlashAttention(torch.nn.Module):
**fa_optional_forward_kwargs, **fa_optional_forward_kwargs,
) )
if 'padding' in attn_mask_type: if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type:
output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
if qkv_format == 'sbhd': if qkv_format == 'sbhd':
...@@ -2230,7 +2228,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2230,7 +2228,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input""" """Function for FusedAttention with packed QKV input"""
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, def forward(ctx, is_training, max_seqlen, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd, rng_gen, fused_attention_backend, use_FAv2_bwd,
fp8, fp8_meta): fp8, fp8_meta):
...@@ -2257,6 +2257,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2257,6 +2257,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, is_training, max_seqlen, cu_seqlens,
qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -2297,6 +2298,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2297,6 +2298,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
...@@ -2305,7 +2307,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2305,7 +2307,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors, *aux_ctx_tensors) ctx.save_for_backward(*qkvo_tensors, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.max_seqlen = max_seqlen ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype ctx.qkv_dtype = qkv_dtype
...@@ -2330,7 +2334,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2330,7 +2334,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
d_out = d_out._data d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
(qkv, out, cu_seqlens, qkv_fp8, out_fp8, (qkv, out, cu_seqlens,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
qkv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
...@@ -2369,6 +2375,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2369,6 +2375,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_fp8, out_fp8, d_out_fp8, qkv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -2404,17 +2411,18 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2404,17 +2411,18 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.max_seqlen, cu_seqlens, qkv, out, d_out, ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv # if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, dqkv, None, None, None, return (None, None, None, None, None, None, None, dqkv, None, None, None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
# else, return (dqkv, dbias) # else, return (dqkv, dbias)
return (None, None, None, dqkv, None, rest[0], None, return (None, None, None, None, None, None, None, dqkv, None, rest[0], None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
...@@ -2424,6 +2432,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2424,6 +2432,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta): use_FAv2_bwd, fp8, fp8_meta):
...@@ -2454,6 +2463,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2454,6 +2463,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -2497,6 +2507,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2497,6 +2507,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias, q, kv, qkv_dtype, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
...@@ -2506,6 +2517,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2506,6 +2517,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors) *fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
...@@ -2532,7 +2544,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2532,7 +2544,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
d_out = d_out._data d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
(q, kv, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, out_fp8, (q, kv, out, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q_fp8, kv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
...@@ -2573,6 +2587,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2573,6 +2587,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
q_fp8, kv_fp8, out_fp8, d_out_fp8, q_fp8, kv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -2620,17 +2635,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -2620,17 +2635,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
q, kv, out, d_out, q, kv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv # if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, dq, dkv, None, None, None, return (None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
# else, return (dqkv, dbias) # else, return (dqkv, dbias)
return (None, None, None, None, None, dq, dkv, None, rest[0], None, return (None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
...@@ -2639,6 +2655,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2639,6 +2655,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta): use_FAv2_bwd, fp8, fp8_meta):
...@@ -2690,6 +2707,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2690,6 +2707,7 @@ class FusedAttnFunc(torch.autograd.Function):
out_fp8, aux_ctx_tensors = fused_attn_fwd( out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -2761,6 +2779,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2761,6 +2779,7 @@ class FusedAttnFunc(torch.autograd.Function):
out_ret, aux_ctx_tensors = fused_attn_fwd( out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, fused_attention_backend, attn_bias, q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
...@@ -2778,6 +2797,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2778,6 +2797,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
*fp8_tensors, *aux_ctx_tensors) *fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
...@@ -2804,7 +2824,9 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2804,7 +2824,9 @@ class FusedAttnFunc(torch.autograd.Function):
d_out = d_out._data d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, out_fp8, (q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
q_fp8, k_fp8, v_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous(): if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
...@@ -2846,6 +2868,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2846,6 +2868,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -2929,17 +2952,20 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2929,17 +2952,20 @@ class FusedAttnFunc(torch.autograd.Function):
q, k, v, out, d_out, q, k, v, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv # if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (None, None, None, None, None, dq, dk, dv, None, None, None, return (None, None, None, None, None, None,
None, None, None, dq, dk, dv, None, None, None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
# else, return (dqkv, dbias) # else, return (dqkv, dbias)
return (None, None, None, None, None, dq, dk, dv, None, rest[0], None, return (None, None, None, None, None, None,
None, None, None, dq, dk, dv, None, rest[0], None,
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
...@@ -3032,6 +3058,10 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -3032,6 +3058,10 @@ class FusedAttention(TransformerEngineBaseModule):
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
...@@ -3047,7 +3077,6 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -3047,7 +3077,6 @@ class FusedAttention(TransformerEngineBaseModule):
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert (fused_attention_backend assert (fused_attention_backend
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!' ), 'No fused attention backend supports this input combination!'
...@@ -3066,9 +3095,6 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -3066,9 +3095,6 @@ class FusedAttention(TransformerEngineBaseModule):
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (
qkv_format != 'thd'
), 'FusedAttention does not support qkv_format = thd!'
if qkv_format in ['sbhd', 'bshd']: if qkv_format in ['sbhd', 'bshd']:
if qkv_format == 'sbhd': if qkv_format == 'sbhd':
...@@ -3104,6 +3130,34 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -3104,6 +3130,34 @@ class FusedAttention(TransformerEngineBaseModule):
max_seqlen_kv, max_seqlen_kv,
key_layer.device, key_layer.device,
) )
if qkv_format == 'thd':
assert not context_parallel, "thd format not supported with context parallelism!"
assert (max_seqlen_q is not None
and max_seqlen_kv is not None
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if (seq_offsets_q is None
or seq_offsets_k is None
or seq_offsets_v is None
or seq_offsets_o is None):
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
num_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
head_dim = query_layer.shape[-1]
seq_offsets_o = num_heads * head_dim * cu_seqlens_q
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype] qkv_dtype = TE_DType[query_layer.dtype]
...@@ -3157,6 +3211,7 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -3157,6 +3211,7 @@ class FusedAttention(TransformerEngineBaseModule):
self.training, self.training,
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q, cu_seqlens_kv,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
query_layer, key_layer, value_layer, query_layer, key_layer, value_layer,
qkv_dtype, qkv_dtype,
core_attention_bias, core_attention_bias,
...@@ -3430,6 +3485,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -3430,6 +3485,10 @@ class DotProductAttention(torch.nn.Module):
qkv_format: Optional[str] = None, qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None, attn_mask_type: Optional[str] = None,
...@@ -3511,6 +3570,18 @@ class DotProductAttention(torch.nn.Module): ...@@ -3511,6 +3570,18 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None` cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`, Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32. with shape [batch_size + 1] and dtype torch.int32.
seq_offsets_q: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_k: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `key_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_v: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `value_layer`,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
seq_offsets_o: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for forward output,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
max_seqlen_q: Optional[int], default = `None` max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`. Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided. Calculated from `cu_seqlens_q` if not provided.
...@@ -3581,6 +3652,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -3581,6 +3652,9 @@ class DotProductAttention(torch.nn.Module):
assert (attn_mask_type in AttnMaskTypes assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!" ), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format == 'thd':
assert ('padding' in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if self.rng_states_tracker is not None and is_graph_capturing(): if self.rng_states_tracker is not None and is_graph_capturing():
assert ( assert (
...@@ -3649,10 +3723,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -3649,10 +3723,10 @@ class DotProductAttention(torch.nn.Module):
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
if max_seqlen_q is None: if max_seqlen_q is None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = seqlens_q.max().item() max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None: if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = seqlens_kv.max().item() max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
if qkv_format in ['sbhd', 'bshd']: if qkv_format in ['sbhd', 'bshd']:
assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)) assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
...@@ -3690,6 +3764,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -3690,6 +3764,10 @@ class DotProductAttention(torch.nn.Module):
# The following section filters out some backends based on # The following section filters out some backends based on
# certain asserts before executing the forward pass. # certain asserts before executing the forward pass.
# Filter: QKV layout.
if qkv_format == 'thd':
use_unfused_attention = False
# Filter: ONNX export. # Filter: ONNX export.
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
use_flash_attention = False use_flash_attention = False
...@@ -3891,6 +3969,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -3891,6 +3969,10 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -3910,6 +3992,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -3910,6 +3992,10 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
......
...@@ -83,6 +83,10 @@ def fused_attn_fwd_qkvpacked( ...@@ -83,6 +83,10 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -118,6 +122,14 @@ def fused_attn_fwd_qkvpacked( ...@@ -118,6 +122,14 @@ def fused_attn_fwd_qkvpacked(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -225,6 +237,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -225,6 +237,7 @@ def fused_attn_fwd_qkvpacked(
max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype, cu_seqlens, qkv, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
rng_gen, rng_elts_per_thread, rng_gen, rng_elts_per_thread,
) )
...@@ -243,6 +256,10 @@ def fused_attn_bwd_qkvpacked( ...@@ -243,6 +256,10 @@ def fused_attn_bwd_qkvpacked(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -286,6 +303,14 @@ def fused_attn_bwd_qkvpacked( ...@@ -286,6 +303,14 @@ def fused_attn_bwd_qkvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -361,6 +386,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -361,6 +386,7 @@ def fused_attn_bwd_qkvpacked(
max_seqlen, attn_scale, dropout, fast_zero_fill, max_seqlen, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
) )
...@@ -379,6 +405,10 @@ def fused_attn_fwd_kvpacked( ...@@ -379,6 +405,10 @@ def fused_attn_fwd_kvpacked(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -421,6 +451,14 @@ def fused_attn_fwd_kvpacked( ...@@ -421,6 +451,14 @@ def fused_attn_fwd_kvpacked(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -529,6 +567,7 @@ def fused_attn_fwd_kvpacked( ...@@ -529,6 +567,7 @@ def fused_attn_fwd_kvpacked(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread, attn_bias, rng_gen, rng_elts_per_thread,
) )
...@@ -550,6 +589,10 @@ def fused_attn_bwd_kvpacked( ...@@ -550,6 +589,10 @@ def fused_attn_bwd_kvpacked(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -600,6 +643,14 @@ def fused_attn_bwd_kvpacked( ...@@ -600,6 +643,14 @@ def fused_attn_bwd_kvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -679,6 +730,7 @@ def fused_attn_bwd_kvpacked( ...@@ -679,6 +730,7 @@ def fused_attn_bwd_kvpacked(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
) )
...@@ -698,6 +750,10 @@ def fused_attn_fwd( ...@@ -698,6 +750,10 @@ def fused_attn_fwd(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -744,6 +800,14 @@ def fused_attn_fwd( ...@@ -744,6 +800,14 @@ def fused_attn_fwd(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -854,6 +918,7 @@ def fused_attn_fwd( ...@@ -854,6 +918,7 @@ def fused_attn_fwd(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread, attn_bias, rng_gen, rng_elts_per_thread,
) )
...@@ -876,6 +941,10 @@ def fused_attn_bwd( ...@@ -876,6 +941,10 @@ def fused_attn_bwd(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -929,6 +998,14 @@ def fused_attn_bwd( ...@@ -929,6 +998,14 @@ def fused_attn_bwd(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -1012,6 +1089,7 @@ def fused_attn_bwd( ...@@ -1012,6 +1089,7 @@ def fused_attn_bwd(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
) )
......
...@@ -31,6 +31,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -31,6 +31,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor cu_seqlens, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -54,6 +58,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -54,6 +58,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
...@@ -76,6 +84,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -76,6 +84,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor Q, const at::Tensor Q,
const at::Tensor KV, const at::Tensor KV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -101,6 +113,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -101,6 +113,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
...@@ -124,6 +140,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -124,6 +140,10 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor K, const at::Tensor K,
const at::Tensor V, const at::Tensor V,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -150,6 +170,10 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -150,6 +170,10 @@ std::vector<at::Tensor> fused_attn_bwd(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
......
...@@ -96,6 +96,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -96,6 +96,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor cu_seqlens, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -123,6 +127,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -123,6 +127,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -169,6 +174,32 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -169,6 +174,32 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract random number generator seed and offset // extract random number generator seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
...@@ -193,6 +224,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -193,6 +224,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_O.data(), te_O.data(),
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen, max_seqlen,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -241,6 +276,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -241,6 +276,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_O.data(), te_O.data(),
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen, max_seqlen,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -266,6 +305,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -266,6 +305,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
...@@ -380,6 +423,33 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -380,6 +423,33 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
...@@ -394,6 +464,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -394,6 +464,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_dQKV.data(), te_dQKV.data(),
te_dBias.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen, max_seqlen,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
...@@ -417,6 +491,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -417,6 +491,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_dQKV.data(), te_dQKV.data(),
te_dBias.data(), te_dBias.data(),
te_cu_seqlens.data(), te_cu_seqlens.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen, max_seqlen,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
...@@ -439,6 +517,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -439,6 +517,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor Q, const at::Tensor Q,
const at::Tensor KV, const at::Tensor KV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -462,6 +544,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -462,6 +544,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -516,6 +599,32 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -516,6 +599,32 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset // extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
...@@ -542,6 +651,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -542,6 +651,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -592,6 +705,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -592,6 +705,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -620,6 +737,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -620,6 +737,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
...@@ -725,6 +846,33 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -725,6 +846,33 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors // convert auxiliary tensors from forward to NVTETensors
NVTETensorPack nvte_aux_tensor_pack; NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
...@@ -771,6 +919,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -771,6 +919,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_dBias.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
...@@ -797,6 +949,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -797,6 +949,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_dBias.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
...@@ -820,6 +976,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -820,6 +976,10 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor K, const at::Tensor K,
const at::Tensor V, const at::Tensor V,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
...@@ -844,6 +1004,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -844,6 +1004,7 @@ std::vector<at::Tensor> fused_attn_fwd(
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -902,6 +1063,32 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -902,6 +1063,32 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset // extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
...@@ -930,6 +1117,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -930,6 +1117,10 @@ std::vector<at::Tensor> fused_attn_fwd(
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -981,6 +1172,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -981,6 +1172,10 @@ std::vector<at::Tensor> fused_attn_fwd(
&nvte_aux_tensor_pack, &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
te_rng_state.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, is_training, attn_scale, p_dropout,
...@@ -1010,6 +1205,10 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1010,6 +1205,10 @@ std::vector<at::Tensor> fused_attn_bwd(
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
...@@ -1183,6 +1382,33 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1183,6 +1382,33 @@ std::vector<at::Tensor> fused_attn_bwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if ((seq_offsets_q.has_value())
&& (seq_offsets_k.has_value())
&& (seq_offsets_v.has_value())
&& (seq_offsets_o.has_value())) {
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec();
std::vector<size_t> seq_offsets_q_shape{
seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{
seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()};
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{
seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{
seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(),
seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(),
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(),
seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(),
seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr);
}
// convert auxiliary tensors from forward to NVTETensors // convert auxiliary tensors from forward to NVTETensors
NVTETensorPack nvte_aux_tensor_pack; NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
...@@ -1231,6 +1457,10 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1231,6 +1457,10 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
...@@ -1259,6 +1489,10 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1259,6 +1489,10 @@ std::vector<at::Tensor> fused_attn_bwd(
te_dBias.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(),
te_seq_offsets_k.data(),
te_seq_offsets_v.data(),
te_seq_offsets_o.data(),
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, qkv_layout, bias_type, attn_mask_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment