Unverified Commit 94f54d71 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

[PyTorch] upgrade context parallelism implementations (#572)



* try to use cuDNN fused attention for context parallelism
Signed-off-by: default avatarxren <xren@nvidia.com>

* assert CP is only supported with NVTE_F16_arbitrary_seqlen
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* port fused attn api to context parallelism
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add one more assert
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert CP does not support padded tokens
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add qkv_format into CP implementation
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove qkv_format from CP function
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix qkv_format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix bwd error with FA v2
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make cp implementation support non-causal masking
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove redundant asserts for CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor assert information change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert core attn bias has not been supported with CP yet
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make CP work with window_sizes of [-1, -1] and [-1, 0]
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add draft code for fa test with cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* move fused attn test to a specific folder
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add assert_close to flash attn cp test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add more tests for CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add optional arguments for FA v2.4+
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add skip condition for CP test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* class and function naming fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* docstring fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not use fused attn if backend does not work with CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* create a separate folder for CP test as it needs multi-GPUs
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add attn_mask_type check in attn_forward_func_with_cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code format fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarxren <xren@nvidia.com>
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarcyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent bb759adc
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Abort immediately if any command fails.
set -e
# Default TE_PATH to the container install location unless the caller set it.
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
# Launch the multi-GPU context-parallelism attention tests.
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
......@@ -10,6 +10,6 @@ pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from test_fused_attn_with_cp import model_configs
# Mapping from command-line dtype names to the corresponding torch dtypes.
dtypes = dict(fp16=torch.float16, bf16=torch.bfloat16)
def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
    """Test DotProductAttention module with context parallelism.

    Runs the same attention problem twice -- once on the full sequence without
    context parallelism (CP) and once with the sequence sharded across ranks --
    and checks that outputs and input gradients match within tolerance.

    Args:
        dtype: key into the module-level ``dtypes`` map ('fp16' or 'bf16').
        model: key into ``model_configs`` imported from test_fused_attn_with_cp.
        qkv_format: layout of the q/k/v tensors, 'bshd' or 'sbhd'.
        kernel_backend: 'FlashAttention' or 'FusedAttention'; selects the
            attention backend via the NVTE_* environment variables below.
    """
    # Disable both backends, then enable exactly the requested one.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if kernel_backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if kernel_backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    # RANK/WORLD_SIZE are injected by the torch.distributed launcher.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    if dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        # Pin this process to one local GPU before initializing NCCL.
        device_count = torch.cuda.device_count()
        device = rank % device_count
        torch.cuda.set_device(device)
    print(f"[INFO] world_size:{world_size}, rank:{rank}")
    # NOTE(review): indentation was lost in extraction; init_process_group is
    # presumably only reached when dist was not already initialized -- confirm.
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
    # create flash attn comm group for CP
    cp_comm_ranks = range(world_size)
    assert(rank in cp_comm_ranks)
    cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
    config = model_configs[model]
    # CP only supports full or causal masks (see attn_forward_func_with_cp).
    assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
    # instantiate core attn module
    core_attn = DotProductAttention(config.num_heads,
                                    config.head_dim,
                                    num_gqa_groups=config.num_gqa_groups,
                                    attention_dropout=config.dropout_p,
                                    qkv_format=qkv_format,
                                    attn_mask_type=config.attn_mask_type)
    core_attn = core_attn.cuda()
    # create flash attn inputs
    if qkv_format == "bshd":
        q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
        kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
    elif qkv_format == "sbhd":
        q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
        kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"
    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
    dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
    # make sure all GPU ranks have same inputs
    for x in [q, k, v, dout]:
        dist.broadcast(x, 0, group=cp_comm_group)
    # run core_attn without CP
    for x in [q, k, v]:
        x.requires_grad = True
    out = core_attn(q, k, v)
    out.backward(dout)
    # run core_attn with CP
    q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
    seq_dim = qkv_format.index('s')
    # Split the sequence dim into 2*world_size chunks; each rank keeps the
    # chunk pair (rank, 2*world_size-rank-1) so causal-mask work is balanced.
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
        for x in [q_, k_, v_, dout_]]
    seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
    q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
    # Merge the two kept chunks back into one contiguous sequence dimension.
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
    q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
    core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
    out_ = core_attn(q_, k_, v_)
    out_.backward(dout_)
    # sanity-check the CP results before numeric comparison
    for x in [out_, q_.grad, k_.grad, v_.grad]:
        assert(torch.all(~torch.isnan(x)))
        assert(torch.all(~torch.isinf(x)))
    # compare results with and without CP
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == 'bf16':
        # bf16 has fewer mantissa bits, so use looser tolerances.
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    # slice the no-CP reference exactly the way the CP inputs were sliced
    dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
        for x in [q.grad, k.grad, v.grad, out]]
    dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
    # split the CP results back into their two sequence chunks for comparison
    dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
        for x in [q_.grad, k_.grad, v_.grad, out_]]
    if qkv_format == "bshd":
        torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
        torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
        torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols)
        torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
        torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
        torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
        torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
        torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
    elif qkv_format == "sbhd":
        torch.testing.assert_close(out_[0], out[0], **tols)
        torch.testing.assert_close(dq_[0], dq[0], **tols)
        torch.testing.assert_close(dk_[0], dk[0], **tols)
        torch.testing.assert_close(dv_[0], dv[0], **tols)
        torch.testing.assert_close(out_[1], out[1], **tols)
        torch.testing.assert_close(dq_[1], dq[1], **tols)
        torch.testing.assert_close(dk_[1], dk[1], **tols)
        torch.testing.assert_close(dv_[1], dv[1], **tols)
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"
def main(**kwargs):
    """Entry point: forward parsed command-line options to the CP test runner."""
    run_dpa_with_cp(**kwargs)
if __name__ == "__main__":
    # Parse "key=value" CLI options into kwargs. Slicing from argv[2:]
    # presumably skips the --local-rank=... argument injected by
    # torch.distributed.launch -- TODO confirm against the launcher used.
    kwargs = dict(arg.split('=') for arg in sys.argv[2:])
    main(**kwargs)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
_cudnn_version,
)
# Model configurations for the CP tests; also imported by
# run_fused_attn_with_cp.py so both processes agree on the problem sizes.
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA
    "cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA
    "cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA
    "cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA
}
def get_bash_arguments(**kwargs):
    """Build the command line that launches the CP test script on 2 GPUs.

    Each keyword argument is appended as a ``key=value`` token, which the
    script parses back into kwargs for ``run_dpa_with_cp``.
    """
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
    script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
    cmd = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2", script_path]
    cmd.extend(f"{key}={value}" for key, value in kwargs.items())
    return cmd
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
    """Launch the 2-GPU CP test with the FlashAttention backend."""
    cmd = get_bash_arguments(
        dtype=dtype,
        model=model,
        qkv_format=qkv_format,
        kernel_backend='FlashAttention',
    )
    # check=True propagates any failure of the distributed run to pytest.
    subprocess.run(cmd, check=True)
@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
    """Launch the 2-GPU CP test with the cuDNN fused-attention backend."""
    cmd = get_bash_arguments(
        dtype=dtype,
        model=model,
        qkv_format=qkv_format,
        kernel_backend='FusedAttention',
    )
    # check=True propagates any failure of the distributed run to pytest.
    subprocess.run(cmd, check=True)
......@@ -360,7 +360,7 @@ class UnpackTensor(torch.autograd.Function):
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
recv_tensor, recv_src,
cp_group, batch_p2p_comm):
"""Point-to-point communications of KV and dKV in Flash Attention with context parallelism"""
"""Point-to-point communications of KV and dKV in Attention with context parallelism"""
send_recv_ops = []
if batch_p2p_comm:
......@@ -405,7 +405,7 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
"""Merge partial outputs of each step in Flash Attention with context parallelism"""
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step*softmax_lse_corrected_exp
......@@ -414,22 +414,23 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism.

    In-place running log-sum-exp update:
        softmax_lse = log(exp(softmax_lse) + exp(softmax_lse_per_step))
    The per-step stats are promoted to double before exponentiation to limit
    accumulation error across CP steps.
    """
    # NOTE: this span of the diff showed both the old and the new docstring
    # lines; only the updated docstring is kept here.
    softmax_lse.exp_()
    softmax_lse.add_(softmax_lse_per_step.to(torch.double).exp())
    softmax_lse.log_()
class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
class AttnFuncWithCP(torch.autograd.Function):
"""
Flash Attention implementation with context parallelism.
Split flash attention compute into multiple steps, and overlap current-step
Attention implementation with context parallelism.
Split attention compute into multiple steps, and overlap current-step
compute with next-step communication.
"""
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic):
def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
deterministic, use_fused_attention):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -439,9 +440,17 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
causal = (attn_mask_type == "causal")
if causal:
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
# Flash Attn inputs
q_inputs = [None, None]
......@@ -479,13 +488,23 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
kv_inputs[i%2] = p2p_comm_buffers[i]
if causal:
fa_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["return_softmax"]=False
if i == 0:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
......@@ -494,56 +513,118 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, **fa_forward_kwargs,
dropout_p, softmax_scale, causal=True, return_softmax=False,
**fa_optional_forward_kwargs
)
elif i <= rank:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
assert False, "Not implemented yet!"
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q, kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
if i > 0:
# wait until fwd restuls correction of last step is done
if i > 1:
flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)
if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq]
softmax_lse_per_step[i-1].squeeze_(-1)
with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
if causal:
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
)
elif (i-1) <= rank:
elif (i-1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
else:
assert False, "Not implemented yet!"
if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
......@@ -554,7 +635,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
for i in range(cp_size):
# [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
if i <= rank:
if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_.shape),
out_,
softmax_lse,
......@@ -566,6 +647,9 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
softmax_lse_per_step[i])
kv = p2p_comm_buffers[-1]
if use_fused_attention:
out = out.view(out.shape[0], -1, *out.shape[-2:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
......@@ -577,6 +661,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
return out
@staticmethod
......@@ -589,10 +674,16 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
if ctx.causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
out = out.view(*q.shape)
dout = dout.view(*q.shape)
# Flash Attn outputs
......@@ -603,6 +694,12 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(cp_size):
# wait until KV is received
for req in send_recv_reqs:
......@@ -628,15 +725,27 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
kv = p2p_comm_buffers[i%2][0]
# In reversed order of fwd
if ctx.causal:
fa_backward_kwargs = {}
if _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["rng_state"]=ctx.rng_states[cp_size-i-1]
if i == (cp_size-1):
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
......@@ -646,55 +755,138 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, 0]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
**fa_backward_kwargs,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
elif i >= (cp_size-rank-1):
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
**fa_backward_kwargs,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
**fa_backward_kwargs,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
**fa_optional_backward_kwargs
)
if i >= (cp_size-rank-1):
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
if i >= (cp_size-rank-1) or not ctx.causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal
dq_ = dq_.view(*dq.shape)
else:
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
if ctx.causal:
if i > (cp_size-rank-1):
dq.add_(dq_)
elif i == (cp_size-rank-1):
......@@ -707,19 +899,28 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
dq[:, 1, ...].add_(dq_)
else:
dq[:, 1, ...].copy_(dq_)
else:
if i == 0:
dq.copy_(dq_)
else:
dq.add_(dq_)
# wait until dKV is received
for req in send_recv_reqs:
req.wait()
dkv = p2p_comm_buffers[(i+1)%2][1]
if i >= (cp_size-rank-1) and i != (cp_size-1):
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn]
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape)
if ctx.causal:
if i == (cp_size-1):
if rank == 0:
dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
......@@ -736,24 +937,32 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
else:
dkv.copy_(dkv_)
else:
assert False, "Not implemented yet!"
if i == 0:
dkv.copy_(dkv_)
else:
dkv.add_(dkv_)
if ctx.causal:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
return dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, None, None, None
return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
None, None, None, None, None, None
def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=None, causal=False,
deterministic=False):
"""Flash Attention implementation with context parallelism"""
out = FlashAttnUnpaddedFuncWithCP.apply(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic
def attn_forward_func_with_cp(
    is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
    cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal",
    deterministic=False, use_fused_attention=False
) -> torch.Tensor:
    """Dispatch attention with context parallelism to ``AttnFuncWithCP``.

    Only "causal" and "no_mask" mask types are accepted; any other mask
    type is rejected up front before the autograd function is invoked.
    """
    # Guard: context parallelism implements only these two mask types.
    supported_mask_types = ("causal", "no_mask")
    assert attn_mask_type in supported_mask_types, \
        f"Mask type of {attn_mask_type} is not supported with context parallelism!"
    return AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
        deterministic, use_fused_attention
    )
......@@ -1343,7 +1552,7 @@ class FlashAttention(torch.nn.Module):
]
if 'padding' in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism."
assert not context_parallel, "Padding mask not supported with context parallelism!"
if self.attention_type == "self":
assert (
......@@ -1402,7 +1611,7 @@ class FlashAttention(torch.nn.Module):
else:
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
elif qkv_format == 'thd':
assert not context_parallel, "thd format is not supported for context parallelism!"
assert not context_parallel, "thd format not supported with context parallelism!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
if max_seqlen_q is None:
......@@ -1420,13 +1629,13 @@ class FlashAttention(torch.nn.Module):
alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism."
with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp(
query_layer, key_layer, value_layer,
output = attn_forward_func_with_cp(
self.training, query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
causal="causal" in attn_mask_type,
attn_mask_type=attn_mask_type,
deterministic=self.deterministic
)
else:
......@@ -1755,6 +1964,9 @@ class FusedAttention(torch.nn.Module):
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
) -> torch.Tensor:
"""fused attention fprop"""
......@@ -1773,6 +1985,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (
qkv_format != 'thd'
......@@ -1786,6 +2000,8 @@ class FusedAttention(torch.nn.Module):
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if 'padding' in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
global _cu_seqlens_q, _cu_seqlens_kv
if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
# use cu_seqlens when both cu_seqlens and attention_mask are present
......@@ -1829,6 +2045,31 @@ class FusedAttention(torch.nn.Module):
and (core_attention_bias_type == "no_bias")
and (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
if context_parallel:
assert (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
), f"{fused_attention_backend} does not work with context parallelism!"
assert (core_attention_bias_type == "no_bias"), \
"Core attention bias has not been supported with context parallelism yet!"
if qkv_format == 'sbhd':
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
with self.attention_dropout_ctx():
output = attn_forward_func_with_cp(
self.training,
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv,
max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
attn_mask_type=attn_mask_type,
use_fused_attention=True,
)
if qkv_format == 'sbhd':
output = output.transpose(0,1).contiguous()
else:
with self.attention_dropout_ctx():
output = FusedAttnFunc.apply(
self.training,
......@@ -2315,12 +2556,13 @@ class DotProductAttention(torch.nn.Module):
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
context_parallel = (self.cp_group is not None and \
get_distributed_world_size(self.cp_group) != 1)
# Filter: sliding window attention.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask.
if window_size not in ((-1, -1), (-1, 0)):
use_fused_attention = False
context_parallel = (self.cp_group is not None
and get_distributed_world_size(self.cp_group) != 1)
if (not _flash_attn_2_3_plus) or context_parallel:
use_flash_attention = False
......@@ -2361,8 +2603,10 @@ class DotProductAttention(torch.nn.Module):
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = (use_fused_attention
and is_backend_avail)
use_fused_attention = ( \
use_fused_attention and is_backend_avail and \
(not context_parallel or \
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
# Filter: Alibi slopes
if alibi_slopes is not None:
......@@ -2415,43 +2659,52 @@ class DotProductAttention(torch.nn.Module):
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv)
assert (
self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
), "Context parallelism is only implemented with Flash Attention!"
if use_fused_attention:
if _NVTE_DEBUG:
print("[DotProductAttention]: using cuDNN fused attention (backend "
+ str(int(fused_attention_backend)) + ")")
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention,
return self._checkpointed_attention_forward(
self.fused_attention,
query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv)
return self.fused_attention(query_layer, key_layer, value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill,
return self.fused_attention(
query_layer,
key_layer,
value_layer,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv)
assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!"
if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA")
if use_unfused_attention:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment