"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "4cefa9b49b6cb2be6d7eac88315df65e0f0d8c9a"
Commit e52bdb41 authored by Sudhakar Singh, committed by Kshitij Janardan Lakhani

Enable SWA with CP for THD input format (#2220)



* Add support for THD+CP+SWA through A2A comms
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* unblock the `padding`+`THD`+`CP(A2A)` with SWA case in A2A forward
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add proper support for thd
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* bug fix
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* enable thd+cp tests as essential
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add cp+thd+a2a test to essential
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix comments from greptile
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add proper skip for flash attention
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the test to create separate tensors for flash and fused attention backend scenarios
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove redundant compare
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* simplify code
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add note for cu_seqlens_kv and cu_seqlens_kv_padded
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update tests/pytorch/attention/test_attention_with_cp.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixo
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix docs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the argument name
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 353a8eea
......@@ -89,40 +89,47 @@ def generate_input_shapes(
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
total_tokens = cu_seqlens_q_padded[-1]
q_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads,
config.head_dim_qk,
)
k_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads * config.head_dim_v,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format=} is not supported!"
......
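As a reading aid (not part of the diff): a minimal sketch of how the updated THD branch above builds `cu_seqlens_q_padded` and `cu_seqlens_q`, assuming the fixed token buffer of `batch_size * max_seqlen_q` used for the input shapes. Each sequence length is first rounded up to a multiple of `2 * world_size` so every sequence can later be split into `2 * cp_size` chunks, and a final entry accounts for the trailing padding up to the buffer size.

import torch

# Toy values for illustration only (hypothetical, not from the test config)
world_size, batch_size, max_seqlen_q = 2, 3, 16
seqlens_q = torch.tensor([5, 11, 0], dtype=torch.int32)

# Round each sequence length up to a multiple of 2 * world_size (= 4 here)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
# seqlens_q_padded -> [8, 12, 0]

total_tokens = batch_size * max_seqlen_q  # fixed THD buffer: 48 tokens

cu_seqlens_q_padded = torch.cat(
    [
        torch.zeros([1], dtype=torch.int32),
        seqlens_q_padded.cumsum(0, dtype=torch.int32),
        torch.tensor([total_tokens], dtype=torch.int32),
    ]
)
# cu_seqlens_q_padded -> [0, 8, 20, 20, 48]; the last entry covers the trailing padding

# FusedAttention additionally gets the unpadded lengths, with the extra last entry collapsed
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
cu_seqlens_q[-1] = cu_seqlens_q[-2]
# cu_seqlens_q -> [0, 5, 16, 16, 16]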
......@@ -7,7 +7,7 @@ import subprocess
import sys
import pathlib
import logging
import copy
import pytest
import torch
from transformer_engine.pytorch import (
......@@ -73,7 +73,7 @@ dtypes = ["bf16", "fp16"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"]
model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
dtypes = ["bf16"]
qkv_formats = ["sbhd", "thd"]
......@@ -96,12 +96,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
......@@ -183,7 +187,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......@@ -224,10 +228,14 @@ def test_cp_with_fused_attention(
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
......@@ -281,6 +289,14 @@ def test_cp_with_fused_attention(
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
if qkv_format == "thd":
config = copy.deepcopy(config)
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
fp8_meta = {}
fp8_meta["recipe"] = None
fp8_meta["local_recipes"] = []
......
......@@ -4,6 +4,7 @@
"""Context Parallelism."""
import os
import itertools
from typing import List, Union, Tuple
import torch
import transformer_engine_torch as tex
......@@ -260,6 +261,146 @@ def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size
return x
def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
"""
Reorder sequence chunks for A2A communication that happens after attention
compute.
Args:
x: The input tensor to be reordered.
cu_seqlens: The cumulative sequence lengths of the input tensor.
cp_size: The number of ranks participating in context parallelism.
seq_dim: The dimension in which to reorder.
Returns:
The reordered tensor.
Example:
x: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3.,
4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
cu_seqlens: [ 0, 8, 16, 24, 40]
cp_size: 4
Returns: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]
This logic mirrors the dual-chunking scheme used to split the sequence for each rank.
Here, the chunk indices for all ranks are concatenated together, so the returned tensor
looks as if the chunks from all the ranks were concatenated along the sequence dimension.
e.g. [
0., 7., 0., 7., 0., 7., 0., 1., 14., 15., # chunk on rank 0
1., 6., 1., 6., 1., 6., 2., 3., 12., 13., # chunk on rank 1
2., 5., 2., 5., 2., 5., 4., 5., 10., 11., # chunk on rank 2
3., 4., 3., 4., 3., 4., 6., 7., 8., 9. # chunk on rank 3
]
"""
total_slices_of_any_sequence = 2 * cp_size
slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
indices = [
(
# 1st segment
torch.arange(
seq_start + (cp_rank * slice_size),
seq_start + ((cp_rank + 1) * slice_size),
device=cu_seqlens.device,
),
# 2nd segment
torch.arange(
seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
device=cu_seqlens.device,
),
)
for cp_rank in range(cp_size)
for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
]
# flatten the list of tuples to a list
indices = list(itertools.chain(*indices))
indices = torch.cat(indices)
return x.index_select(seq_dim, indices)
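# Editor's illustration (not part of the original diff): the docstring example above,
# reproduced with the helper just defined. The tensors are hypothetical toy data.
#
#   x = torch.cat([torch.arange(8.).repeat(3), torch.arange(16.)])
#   cu_seqlens = torch.tensor([0, 8, 16, 24, 40])
#   y = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size=4)
#   y[:10]  ->  [0., 7., 0., 7., 0., 7., 0., 1., 14., 15.]  # the chunk destined for rank 0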
def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
"""
Reorder sequence chunks for A2A communication that happens before attention
compute.
Args:
x: The input tensor to be reordered.
cu_seqlens: The cumulative sequence lengths of the input tensor.
seq_chunk_ids: The sequence chunk ids of the input `x` which is to be reordered.
cp_size: The number of ranks participating in context parallelism.
seq_dim: The dimension in which to reorder.
Returns:
The reordered tensor.
Example:
x: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]
cu_seqlens: [ 0, 8, 16, 24, 40]
seq_chunk_ids: [ 0, 2, 4, 6, 7, 5, 3, 1]
cp_size: 4
Returns: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3.,
4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
Note that the input sequences (x) are arranged after A2A communication as if DualChunked
chunks on all the ranks are concatenated together in the `seq_dim`.
e.g. [
0., 7., 0., 7., 0., 7., 0., 1., 14., 15., # chunk on rank 0
1., 6., 1., 6., 1., 6., 2., 3., 12., 13., # chunk on rank 1
2., 5., 2., 5., 2., 5., 4., 5., 10., 11., # chunk on rank 2
3., 4., 3., 4., 3., 4., 6., 7., 8., 9. # chunk on rank 3
]
Then the logic to serialize the sequences is:
1. For every sequence segment on any rank (denoted by `start` and `end`):
1a. For every chunk (`chunk_id`; there are twice as many chunks as CP ranks):
1aa. The first `cp_size` chunks form the first half of the whole sequence. Get those indices.
1ab. The second `cp_size` chunks form the second half of the whole sequence. Get those indices.
1b. Concatenate the indices of the first half and the second half.
2. Reorder the entire input tensor by those indices.
"""
max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size
# Go through all the sequence segments (the sizes should be the same from all the ranks)
indices = [
torch.arange(
# Calculate 'left' boundary
(
start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
if loc < cp_size
else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
),
# Calculate 'right' boundary
(
(start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
if loc < cp_size
else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
),
device=cu_seqlens.device,
)
for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
for loc, chunk_id in enumerate(seq_chunk_ids)
]
indices = torch.cat(indices)
return x.index_select(seq_dim, indices)
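# Editor's illustration (not part of the original diff): for a fixed set of chunk ids,
# the two THD reorder helpers above are mutual inverses. Reusing the docstring example
# (cp_size=4, chunk ids as listed there for the pre-attention reordering):
#
#   x = torch.cat([torch.arange(8.).repeat(3), torch.arange(16.)])
#   cu_seqlens = torch.tensor([0, 8, 16, 24, 40])
#   chunk_ids = torch.tensor([0, 2, 4, 6, 7, 5, 3, 1])
#   y = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size=4)
#   z = reorder_seq_chunks_after_a2a_before_attn_thd(y, cu_seqlens, chunk_ids, cp_size=4)
#   torch.equal(x, z)  ->  True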
def flash_attn_a2a_communicate(
a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
chunk_ids_for_a2a: torch.Tensor,
......@@ -268,8 +409,14 @@ def flash_attn_a2a_communicate(
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
qkv_format: str = "bshd",
cu_seqlens_padded: torch.Tensor = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""A2A communication for context parallelism."""
assert (
qkv_format != "thd" or cu_seqlens_padded is not None
), "cu_seqlens_padded is required for THD format!"
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
if before_attn:
......@@ -283,20 +430,33 @@ def flash_attn_a2a_communicate(
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a_before_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
# [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d]
# or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d]
a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
if qkv_format in ["bshd", "sbhd"]:
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a_before_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
a2a_outputs[i - 2] = x.view(
*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]
)
else: # qkv_format == "thd"
# [cp, t, np//cp, hn] -> [cp*t, np//cp, hn]
x = x.view(-1, *x.shape[2:])
# reorder the sequence chunks
a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd(
x, cu_seqlens_padded, chunk_ids_for_a2a, cp_size
)
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, s, h, d] -> [b, s, cp, h//cp, d]
# or [s, b, h, d] -> [s, b, cp, h//cp, d]
# [b, s, np, hn] -> [b, s, cp, np//cp, hn]
# or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
# or [t, np, hn] -> [t, cp, np//cp, hn]
x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
# [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d]
# or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d]
# [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
# or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
# or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn]
a2a_inputs[i] = x.movedim(-3, 0).contiguous()
else:
for i in range(len(a2a_inputs) + 2):
......@@ -307,22 +467,30 @@ def flash_attn_a2a_communicate(
)
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d]
# or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
if qkv_format in ["bshd", "sbhd"]:
# [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
# or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
else: # qkv_format == "thd"
# reorder the sequence chunks
x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size)
# [cp*t, np//cp, hn] -> [cp, t, np//cp, hn]
a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:])
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d]
# or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d]
# [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
# or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
# or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn]
x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
# [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d]
# or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d]
# [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
# or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
# or [t, cp, np//cp, hn] -> [t, np, hn]
a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
torch.cuda.current_stream().wait_stream(cp_stream)
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
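For orientation (an editor's sketch with toy, hypothetical sizes, not part of the diff): in the THD branch above only the head dimension is reshaped around the A2A, since the tokens already sit in a flat [t, np, hn] layout; the per-sequence reordering is delegated to the two THD helpers.

import torch

cp, t, np_, hn = 4, 40, 8, 64      # toy sizes: CP ranks, tokens per rank, heads, head dim
q = torch.randn(t, np_, hn)        # THD layout on one rank: [t, np, hn]

# Before the pre-attention A2A: split the heads across CP ranks and move that split up front
q_split = q.view(t, cp, np_ // cp, hn).movedim(-3, 0).contiguous()  # [cp, t, np//cp, hn]

# After the A2A each rank holds every rank's tokens for its head slice; the leading cp
# dimension is flattened before the THD reorder helper serializes the sequences
q_recv = q_split.view(-1, np_ // cp, hn)                            # [cp*t, np//cp, hn]
assert q_recv.shape == (cp * t, np_ // cp, hn)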
......@@ -3145,7 +3313,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
assert not padding, f"{attn_mask_type} mask type is not supported!"
assert (
not padding or qkv_format == "thd"
), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!"
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
......@@ -3196,11 +3366,14 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
), "The number of attention heads needs to be divisible by CP size!"
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
batch_dim = qkv_format.index("b")
seq_dim = qkv_format.index("s")
if qkv_format in ["bshd", "sbhd"]:
batch_dim = qkv_format.index("b")
seq_dim = qkv_format.index("s")
else: # qkv_format == "thd"
batch_dim = seq_dim = qkv_format.index("t")
assert (
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
......@@ -3246,7 +3419,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
[q, k, v],
chunk_ids_for_a2a,
seq_dim,
cp_size,
cp_group,
cp_stream,
before_attn=True,
qkv_format=qkv_format,
cu_seqlens_padded=cu_seqlens_q_padded,
)
if softmax_type != "vanilla":
softmax_offset = flash_attn_a2a_communicate_softmax_offset(
......@@ -3337,7 +3518,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device)
out_ = flash_attn_a2a_communicate(
out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
out_,
chunk_ids_for_a2a,
seq_dim,
cp_size,
cp_group,
cp_stream,
before_attn=False,
qkv_format=qkv_format,
cu_seqlens_padded=cu_seqlens_q_padded,
)
if return_max_logit:
max_logit = flash_attn_a2a_communicate_softmax_offset(
......@@ -3454,9 +3643,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded,
*aux_ctx_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
qkv_format = ctx.qkv_format
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s")
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
else: # qkv_format == "thd"
seq_dim = qkv_format.index("t")
bwd_nominal_dtype = ctx.fwd_nominal_dtype
dqkv_te_dtype = None
......@@ -3486,14 +3681,23 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if not ctx.use_fused_attention:
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:])
if qkv_format in ["bshd", "sbhd"]:
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:])
else:
dout = dout.view(*ctx.out_shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device)
dout = flash_attn_a2a_communicate(
dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
dout,
chunk_ids_for_a2a,
seq_dim,
cp_size,
ctx.cp_group,
ctx.cp_stream,
before_attn=True,
qkv_format=qkv_format,
cu_seqlens_padded=cu_seqlens_q_padded,
)
flash_attn_bwd = None
......@@ -3510,7 +3714,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fa_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
if ctx.qkv_format == "thd":
if qkv_format == "thd":
from transformer_engine.pytorch.attention.dot_product_attention.backends import (
_flash_attn_varlen_bwd,
)
......@@ -3579,7 +3783,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=ctx.max_seqlen_q,
......@@ -3604,12 +3808,20 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device)
dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
[dq, dk, dv],
chunk_ids_for_a2a,
seq_dim,
cp_size,
ctx.cp_group,
ctx.cp_stream,
before_attn=False,
qkv_format=qkv_format,
cu_seqlens_padded=cu_seqlens_q_padded,
)
if ctx.qkv_format == "bshd":
if qkv_format == "bshd":
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
elif ctx.qkv_format == "sbhd":
elif qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
d_bias = None
......