Unverified commit 94f54d71 authored by Xiaowei Ren, committed by GitHub

[PyTorch] upgrade context parallelism implementations (#572)



* try to use cuDNN fused attention for context parallelism
Signed-off-by: xren <xren@nvidia.com>

* assert CP is only supported with NVTE_F16_arbitrary_seqlen
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* port fused attn api to context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add one more assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert CP does not support padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add qkv_format into CP implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove qkv_format from CP function
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv_format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bwd error with FA v2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make CP implementation support non-causal masking
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant asserts for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor assert information change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert core attn bias has not been supported with CP yet
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make CP work with window_sizes of [-1, -1] and [-1, 0]
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft code for fa test with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* move fused attn test to a specific folder
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add assert_close to flash attn cp test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more tests for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add optional arguments for FA v2.4+
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add skip condition for CP test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* class and function naming fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* docstring fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not use fused attn if backend does not work with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* create a separate folder for CP test as it needs multiple GPUs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add attn_mask_type check in attn_forward_func_with_cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: xren <xren@nvidia.com>
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent bb759adc
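For orientation, here is a minimal sketch of how the context-parallel attention path in this PR is driven from user code. It is not part of the diff; the process-group setup, shapes, and dtype are illustrative assumptions, and it presumes a 2-GPU NCCL launch (e.g. `torchrun --nproc-per-node=2`).

```python
# Hypothetical usage sketch (not part of this commit); assumes 2 GPUs and an NCCL launch.
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
cp_group = dist.new_group(range(world_size), backend="nccl")

attn = DotProductAttention(12, 128, attention_dropout=0.0,
                           qkv_format="bshd", attn_mask_type="causal").cuda()
# Register the CP communication group, its global ranks, and a side stream with the module.
attn.set_context_parallel_group(cp_group, list(range(world_size)), torch.cuda.Stream())

# Each rank holds only its shard of the sequence dimension (16384 tokens total here).
b, s_local, h, d = 1, 16384 // world_size, 12, 128
q, k, v = [torch.randn(b, s_local, h, d, dtype=torch.bfloat16, device="cuda")
           for _ in range(3)]
out = attn(q, k, v)  # [b, s_local, h*d]
```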
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -10,6 +10,6 @@ pip install pytest==6.2.5 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
 PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
+pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from test_fused_attn_with_cp import model_configs
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
"""Test DotProductAttention module with context parallelism"""
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert(rank in cp_comm_ranks)
cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
config = model_configs[model]
assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
# instantiate core attn module
core_attn = DotProductAttention(config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
# make sure all GPU ranks have same inputs
for x in [q, k, v, dout]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(q, k, v)
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
out_ = core_attn(q_, k_, v_)
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert(torch.all(~torch.isnan(x)))
assert(torch.all(~torch.isinf(x)))
# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
tols = dict(atol=2.5e-2, rtol=2.5e-2)
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols)
torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
elif qkv_format == "sbhd":
torch.testing.assert_close(out_[0], out[0], **tols)
torch.testing.assert_close(dq_[0], dq[0], **tols)
torch.testing.assert_close(dk_[0], dk[0], **tols)
torch.testing.assert_close(dv_[0], dv[0], **tols)
torch.testing.assert_close(out_[1], out[1], **tols)
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
def main(**kwargs):
run_dpa_with_cp(**kwargs)
if __name__ == "__main__":
kwargs = dict(arg.split('=') for arg in sys.argv[2:])
main(**kwargs)
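The reshard in run_dpa_with_cp above is the load-balancing split used throughout this PR: the sequence is cut into 2*world_size chunks and rank r keeps chunks r and 2*world_size-1-r, so under a causal mask every rank gets one lightly loaded early chunk and one heavily loaded late chunk. A tiny standalone illustration (values are made up):

```python
import torch

world_size, rank = 2, 0
seq_len = 8                      # assume divisible by 2*world_size
x = torch.arange(seq_len)        # stand-in for the sequence dimension of q/k/v
# split into 2*world_size chunks: [[0,1],[2,3],[4,5],[6,7]]
chunks = x.view(2 * world_size, seq_len // (2 * world_size))
# rank 0 keeps chunks 0 and 3 -> tokens [0,1,6,7]; rank 1 keeps chunks 1 and 2 -> [2,3,4,5]
keep = torch.tensor([rank, 2 * world_size - 1 - rank])
local = chunks.index_select(0, keep).reshape(-1)
print(local)    # tensor([0, 1, 6, 7])
```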
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
_cudnn_version,
)
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA
}
def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FlashAttention'
),
check=True
)
@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FusedAttention'
),
check=True
)
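For reference, get_bash_arguments composes a torch.distributed.launch command; for one parametrization it returns roughly the following list (assuming the default TE_PATH):

```python
["python", "-m", "torch.distributed.launch", "--nproc-per-node=2",
 "/opt/transformerengine/tests/pytorch/fused_attn/run_fused_attn_with_cp.py",
 "dtype=bf16", "model=cp_1_0", "qkv_format=bshd", "kernel_backend=FlashAttention"]
```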
@@ -360,7 +360,7 @@ class UnpackTensor(torch.autograd.Function):
 def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                                recv_tensor, recv_src,
                                cp_group, batch_p2p_comm):
-    """Point-to-point communications of KV and dKV in Flash Attention with context parallelism"""
+    """Point-to-point communications of KV and dKV in Attention with context parallelism"""
     send_recv_ops = []
     if batch_p2p_comm:
@@ -405,7 +405,7 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
 @jit_fuser
 def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
-    """Merge partial outputs of each step in Flash Attention with context parallelism"""
+    """Merge partial outputs of each step in Attention with context parallelism"""
     softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2)
     softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
     out_corrected = out_per_step*softmax_lse_corrected_exp
@@ -414,22 +414,23 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
 @jit_fuser
 def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
-    """Merge softmax stats of each step in Flash Attention with context parallelism"""
+    """Merge softmax stats of each step in Attention with context parallelism"""
     softmax_lse.exp_()
     softmax_lse.add_(softmax_lse_per_step.to(torch.double).exp())
     softmax_lse.log_()
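The two helpers above implement the standard online-softmax merge of per-step partial results: the running log-sum-exp becomes lse = log(exp(lse) + exp(lse_step)), and each partial output is rescaled by exp(lse_step - lse) before accumulation. A self-contained toy version of the same idea (simplified shapes, not the actual kernels):

```python
import torch

def merge_two_steps(out1, lse1, out2, lse2):
    """Toy merge of two attention partials: out_i are [tokens, hidden],
    lse_i are the per-token log-sum-exp statistics [tokens]."""
    lse = torch.logaddexp(lse1.double(), lse2.double())  # combined softmax statistics
    w1 = torch.exp(lse1.double() - lse).unsqueeze(-1)    # weight of step 1 in the total
    w2 = torch.exp(lse2.double() - lse).unsqueeze(-1)    # weight of step 2 in the total
    return (w1 * out1.double() + w2 * out2.double()).to(out1.dtype)
```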
-class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
+class AttnFuncWithCP(torch.autograd.Function):
     """
-    Flash Attention implementation with context parallelism.
-    Split flash attention compute into multiple steps, and overlap current-step
+    Attention implementation with context parallelism.
+    Split attention compute into multiple steps, and overlap current-step
     compute with next-step communication.
     """
     @staticmethod
-    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic):
+    def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
+                deterministic, use_fused_attention):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -439,9 +440,17 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
         recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

-        # [b, s, np, hn] -> [b, 2, s//2, np, hn]
-        q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
+        causal = (attn_mask_type == "causal")
+        if causal:
+            # [b, s, np, hn] -> [b, 2, s//2, np, hn]
+            q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]

         assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
+        fa_optional_forward_kwargs = {}
+        if _flash_attn_2_3_plus:
+            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
+        if _flash_attn_2_4_plus:
+            fa_optional_forward_kwargs["alibi_slopes"] = None

         # Flash Attn inputs
         q_inputs = [None, None]
@@ -479,71 +488,143 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                     kv_inputs[i%2] = p2p_comm_buffers[i]
                     if causal:
fa_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["return_softmax"]=False
if i == 0: if i == 0:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] if use_fused_attention:
q_inputs[i%2] = q.view(-1, *q.shape[-2:]) # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
_, _, _, _, out_per_step[i], \ kv_inputs[i%2] = kv_inputs[i%2].view(
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( 2, k.shape[0], -1, *k.shape[-2:])
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, fused_attn_fwd(
dropout_p, softmax_scale, causal=True, **fa_forward_kwargs, is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
) cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, return_softmax=False,
**fa_optional_forward_kwargs
)
elif i <= rank: elif i <= rank:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] if use_fused_attention:
q_inputs[i%2] = q.view(-1, *q.shape[-2:]) # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
_, _, _, _, out_per_step[i], \ out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( fused_attn_fwd(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2, cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs, kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q, kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
) )
else: else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
_, _, _, _, out_per_step[i], \ _, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs, dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
) )
else:
assert False, "Not implemented yet!"
             if i > 0:
                 # wait until fwd results correction of last step is done
                 if i > 1:
                     flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)
if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq]
softmax_lse_per_step[i-1].squeeze_(-1)
with torch.cuda.stream(flash_attn_streams[(i-1)%2]): with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
if causal: if i == 1:
if i == 1: out = torch.empty_like(q).zero_()
out = torch.empty_like(q).zero_() softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal:
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view( softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2 *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
) )
elif (i-1) <= rank: elif (i-1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction(softmax_lse, flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1]) softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
else: else:
assert False, "Not implemented yet!" flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if i < cp_size: if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done) flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
...@@ -554,7 +635,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -554,7 +635,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
for i in range(cp_size): for i in range(cp_size):
# [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn] # [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
if i <= rank: if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_.shape), flash_attn_fwd_out_correction(out.view(*out_.shape),
out_, out_,
softmax_lse, softmax_lse,
...@@ -566,7 +647,10 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -566,7 +647,10 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
softmax_lse_per_step[i]) softmax_lse_per_step[i])
kv = p2p_comm_buffers[-1] kv = p2p_comm_buffers[-1]
out = out.view(-1, *out.shape[-2:]) if use_fused_attention:
out = out.view(out.shape[0], -1, *out.shape[-2:])
else:
out = out.view(-1, *out.shape[-2:])
         ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
         ctx.rng_states = rng_states
         ctx.cp_group = cp_group
@@ -577,6 +661,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.deterministic = deterministic
+        ctx.use_fused_attention = use_fused_attention
         return out
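To summarize the causal forward schedule implemented above: KV travels around the ring (each rank receives from rank-1), and at step i a rank attends with a causal mask on its own KV (i == 0), with full Q against only the first half of the received KV and no mask (i <= rank), or with only the second half of Q against the full received KV and no mask (i > rank). A small illustrative helper that spells out this schedule (not part of the diff):

```python
def cp_causal_schedule(rank, cp_size):
    """Which Q/KV halves and mask each step uses on a given rank (illustrative only)."""
    steps = []
    for i in range(cp_size):
        src = (rank - i) % cp_size  # rank whose KV chunk is held at step i
        if i == 0:
            steps.append(f"step {i}: KV from rank {src}, full Q vs full KV, causal mask")
        elif i <= rank:
            steps.append(f"step {i}: KV from rank {src}, full Q vs first KV half, no mask")
        else:
            steps.append(f"step {i}: KV from rank {src}, second Q half vs full KV, no mask")
    return steps

print("\n".join(cp_causal_schedule(rank=1, cp_size=4)))
```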
@staticmethod @staticmethod
...@@ -589,10 +674,16 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -589,10 +674,16 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
# [b, np, sq] -> [b, np, 2, sq//2] if ctx.causal:
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) # [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse_[..., 1, :].contiguous() softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
out = out.view(*q.shape) out = out.view(*q.shape)
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
# Flash Attn outputs # Flash Attn outputs
...@@ -603,6 +694,12 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -603,6 +694,12 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
p2p_comm_buffers[0][0].copy_(kv) p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = [] send_recv_reqs = []
fa_optional_backward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(cp_size): for i in range(cp_size):
# wait until KV is received # wait until KV is received
for req in send_recv_reqs: for req in send_recv_reqs:
...@@ -628,73 +725,168 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -628,73 +725,168 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
kv = p2p_comm_buffers[i%2][0] kv = p2p_comm_buffers[i%2][0]
# In reversed order of fwd # In reversed order of fwd
if ctx.causal: if ctx.causal:
fa_backward_kwargs = {}
if _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["rng_state"]=ctx.rng_states[cp_size-i-1]
if i == (cp_size-1): if i == (cp_size-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, 0]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
elif i >= (cp_size-rank-1):
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:]) q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_) dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:]) kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_) dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:]) out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward( _flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse, dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k, dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
**fa_backward_kwargs,
)
elif i >= (cp_size-rank-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False, ctx.dropout_p, ctx.softmax_scale, False,
**fa_backward_kwargs, **fa_optional_backward_kwargs
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
**fa_backward_kwargs,
) )
if i >= (cp_size-rank-1): if i >= (cp_size-rank-1) or not ctx.causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
dq_ = dq_.view(*dq.shape) # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
else: dq_ = dq_.view(*dq.shape)
# [b*sq//2, np, hn] -> [b, sq//2, np, hn] else:
dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
if ctx.causal:
if i > (cp_size-rank-1): if i > (cp_size-rank-1):
dq.add_(dq_) dq.add_(dq_)
elif i == (cp_size-rank-1): elif i == (cp_size-rank-1):
...@@ -707,19 +899,28 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -707,19 +899,28 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
dq[:, 1, ...].add_(dq_) dq[:, 1, ...].add_(dq_)
else: else:
dq[:, 1, ...].copy_(dq_) dq[:, 1, ...].copy_(dq_)
else:
if i == 0:
dq.copy_(dq_)
else:
dq.add_(dq_)
# wait until dKV is received # wait until dKV is received
for req in send_recv_reqs: for req in send_recv_reqs:
req.wait() req.wait()
dkv = p2p_comm_buffers[(i+1)%2][1] dkv = p2p_comm_buffers[(i+1)%2][1]
if i >= (cp_size-rank-1) and i != (cp_size-1): if ctx.use_fused_attention:
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
else: # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
dkv_ = dkv_.view(*dkv.shape) else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape)
if ctx.causal:
if i == (cp_size-1): if i == (cp_size-1):
if rank == 0: if rank == 0:
dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
...@@ -736,24 +937,32 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function): ...@@ -736,24 +937,32 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
else: else:
dkv.copy_(dkv_) dkv.copy_(dkv_)
else: else:
assert False, "Not implemented yet!" if i == 0:
dkv.copy_(dkv_)
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] else:
dq = dq.view(q.shape[0], -1, *q.shape[-2:]) dkv.add_(dkv_)
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) if ctx.causal:
return dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, None, None, None # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k, dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
max_seqlen_q, max_seqlen_k, dropout_p, return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
cp_group, cp_global_ranks, cp_stream, None, None, None, None, None, None
softmax_scale=None, causal=False,
deterministic=False):
"""Flash Attention implementation with context parallelism""" def attn_forward_func_with_cp(
out = FlashAttnUnpaddedFuncWithCP.apply( is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal",
cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic deterministic=False, use_fused_attention=False
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
assert (attn_mask_type in ["causal", "no_mask"]
), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
out = AttnFuncWithCP.apply(
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
deterministic, use_fused_attention
) )
return out return out
...@@ -1343,7 +1552,7 @@ class FlashAttention(torch.nn.Module): ...@@ -1343,7 +1552,7 @@ class FlashAttention(torch.nn.Module):
] ]
if 'padding' in attn_mask_type: if 'padding' in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism." assert not context_parallel, "Padding mask not supported with context parallelism!"
if self.attention_type == "self": if self.attention_type == "self":
assert ( assert (
...@@ -1402,7 +1611,7 @@ class FlashAttention(torch.nn.Module): ...@@ -1402,7 +1611,7 @@ class FlashAttention(torch.nn.Module):
else: else:
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
elif qkv_format == 'thd': elif qkv_format == 'thd':
assert not context_parallel, "thd format is not supported for context parallelism!" assert not context_parallel, "thd format not supported with context parallelism!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
if max_seqlen_q is None: if max_seqlen_q is None:
...@@ -1420,13 +1629,13 @@ class FlashAttention(torch.nn.Module): ...@@ -1420,13 +1629,13 @@ class FlashAttention(torch.nn.Module):
alibi_slopes is None alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism." ), "Alibi slope bias addition is not supported with context parallelism."
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp( output = attn_forward_func_with_cp(
query_layer, key_layer, value_layer, self.training, query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0, self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream, cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor, softmax_scale=1.0/self.norm_factor,
causal="causal" in attn_mask_type, attn_mask_type=attn_mask_type,
deterministic=self.deterministic deterministic=self.deterministic
) )
else: else:
...@@ -1755,6 +1964,9 @@ class FusedAttention(torch.nn.Module): ...@@ -1755,6 +1964,9 @@ class FusedAttention(torch.nn.Module):
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
...@@ -1773,6 +1985,8 @@ class FusedAttention(torch.nn.Module): ...@@ -1773,6 +1985,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout in QKVLayouts qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!" ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert ( assert (
qkv_format != 'thd' qkv_format != 'thd'
...@@ -1786,6 +2000,8 @@ class FusedAttention(torch.nn.Module): ...@@ -1786,6 +2000,8 @@ class FusedAttention(torch.nn.Module):
batch_size, max_seqlen_q, max_seqlen_kv = ( batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1]) query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if 'padding' in attn_mask_type: if 'padding' in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
global _cu_seqlens_q, _cu_seqlens_kv global _cu_seqlens_q, _cu_seqlens_kv
if (cu_seqlens_q is not None and cu_seqlens_kv is not None): if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
# use cu_seqlens when both cu_seqlens and attention_mask are present # use cu_seqlens when both cu_seqlens and attention_mask are present
...@@ -1829,24 +2045,49 @@ class FusedAttention(torch.nn.Module): ...@@ -1829,24 +2045,49 @@ class FusedAttention(torch.nn.Module):
and (core_attention_bias_type == "no_bias") and (core_attention_bias_type == "no_bias")
and (fused_attention_backend and (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)) == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
with self.attention_dropout_ctx():
output = FusedAttnFunc.apply( if context_parallel:
self.training, assert (fused_attention_backend
max_seqlen_q, max_seqlen_kv, == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
cu_seqlens_q, cu_seqlens_kv, ), f"{fused_attention_backend} does not work with context parallelism!"
query_layer, key_layer, value_layer, assert (core_attention_bias_type == "no_bias"), \
qkv_dtype, "Core attention bias has not been supported with context parallelism yet!"
core_attention_bias, if qkv_format == 'sbhd':
1.0/self.norm_factor, query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
self.attention_dropout if self.training else 0.0, for x in (query_layer, key_layer, value_layer)]
fast_zero_fill, with self.attention_dropout_ctx():
qkv_layout, output = attn_forward_func_with_cp(
core_attention_bias_type, self.training,
attn_mask_type, query_layer, key_layer, value_layer,
None, # rng_gen cu_seqlens_q, cu_seqlens_kv,
fused_attention_backend, max_seqlen_q, max_seqlen_kv,
use_FAv2_bwd, self.attention_dropout if self.training else 0.0,
) cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
attn_mask_type=attn_mask_type,
use_fused_attention=True,
)
if qkv_format == 'sbhd':
output = output.transpose(0,1).contiguous()
else:
with self.attention_dropout_ctx():
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
query_layer, key_layer, value_layer,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
)
# ...hd -> ...(hd) # ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1) return output.view(*output.shape[:-2], -1)
...@@ -2315,12 +2556,13 @@ class DotProductAttention(torch.nn.Module): ...@@ -2315,12 +2556,13 @@ class DotProductAttention(torch.nn.Module):
if core_attention_bias_type != "no_bias" or core_attention_bias is not None: if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False use_flash_attention = False
context_parallel = (self.cp_group is not None and \
get_distributed_world_size(self.cp_group) != 1)
# Filter: sliding window attention. # Filter: sliding window attention.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask. # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
if window_size not in ((-1, -1), (-1, 0)): if window_size not in ((-1, -1), (-1, 0)):
use_fused_attention = False use_fused_attention = False
context_parallel = (self.cp_group is not None
and get_distributed_world_size(self.cp_group) != 1)
if (not _flash_attn_2_3_plus) or context_parallel: if (not _flash_attn_2_3_plus) or context_parallel:
use_flash_attention = False use_flash_attention = False
...@@ -2361,8 +2603,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -2361,8 +2603,10 @@ class DotProductAttention(torch.nn.Module):
# DPA does not support FP8; for FP8, use cpp_extensions modules directly # DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]]) [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = (use_fused_attention use_fused_attention = ( \
and is_backend_avail) use_fused_attention and is_backend_avail and \
(not context_parallel or \
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
# Filter: Alibi slopes # Filter: Alibi slopes
if alibi_slopes is not None: if alibi_slopes is not None:
...@@ -2415,42 +2659,51 @@ class DotProductAttention(torch.nn.Module): ...@@ -2415,42 +2659,51 @@ class DotProductAttention(torch.nn.Module):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv) max_seqlen_kv=max_seqlen_kv)
assert (
self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
), "Context parallelism is only implemented with Flash Attention!"
if use_fused_attention: if use_fused_attention:
if _NVTE_DEBUG: if _NVTE_DEBUG:
print("[DotProductAttention]: using cuDNN fused attention (backend " print("[DotProductAttention]: using cuDNN fused attention (backend "
+ str(int(fused_attention_backend)) + ")") + str(int(fused_attention_backend)) + ")")
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention, return self._checkpointed_attention_forward(
query_layer, self.fused_attention,
key_layer, query_layer,
value_layer, key_layer,
qkv_layout = qkv_layout, value_layer,
cu_seqlens_q = cu_seqlens_q, qkv_layout=qkv_layout,
cu_seqlens_kv = cu_seqlens_kv, cu_seqlens_q=cu_seqlens_q,
attn_mask_type = attn_mask_type, cu_seqlens_kv=cu_seqlens_kv,
attention_mask = attention_mask, attn_mask_type=attn_mask_type,
fused_attention_backend = fused_attention_backend, attention_mask=attention_mask,
core_attention_bias_type = core_attention_bias_type, fused_attention_backend=fused_attention_backend,
core_attention_bias = core_attention_bias, core_attention_bias_type=core_attention_bias_type,
fast_zero_fill = fast_zero_fill, core_attention_bias=core_attention_bias,
max_seqlen_q=max_seqlen_q, fast_zero_fill=fast_zero_fill,
max_seqlen_kv=max_seqlen_kv) cp_group=self.cp_group,
return self.fused_attention(query_layer, key_layer, value_layer, cp_global_ranks=self.cp_global_ranks,
qkv_layout = qkv_layout, cp_stream=self.cp_stream,
cu_seqlens_q = cu_seqlens_q, max_seqlen_q=max_seqlen_q,
cu_seqlens_kv = cu_seqlens_kv, max_seqlen_kv=max_seqlen_kv)
attn_mask_type = attn_mask_type, return self.fused_attention(
attention_mask = attention_mask, query_layer,
fused_attention_backend = fused_attention_backend, key_layer,
core_attention_bias_type = core_attention_bias_type, value_layer,
core_attention_bias = core_attention_bias, qkv_layout=qkv_layout,
fast_zero_fill = fast_zero_fill, cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_kv=max_seqlen_kv) attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv)
assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!"
if _NVTE_DEBUG: if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA") print("[DotProductAttention]: using unfused DPA")
......
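Taken together, the DotProductAttention changes above mean context parallelism is now served by either backend: the flash-attention path calls attn_forward_func_with_cp directly, the cuDNN fused-attention path requires the F16_arbitrary_seqlen backend with no attention bias (transposing sbhd inputs to bshd first), and the unfused fallback still asserts that CP is not in use. A condensed sketch of that dispatch (simplified, not the literal code):

```python
# Simplified view of the new backend selection with context parallelism (CP);
# names mirror the diff, but the control flow is condensed for illustration.
def pick_backend(use_flash, use_fused, fused_backend, context_parallel):
    if use_fused and context_parallel:
        # CP only works with the arbitrary-seqlen cuDNN backend and no attention bias.
        use_fused = (fused_backend == "F16_arbitrary_seqlen")
    if use_flash:
        return "FlashAttention (attn_forward_func_with_cp)" if context_parallel else "FlashAttention"
    if use_fused:
        return "FusedAttention (attn_forward_func_with_cp)" if context_parallel else "FusedAttention"
    assert not context_parallel, \
        "Context parallelism is only implemented with Flash Attention and Fused Attention!"
    return "UnfusedDotProductAttention"
```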