Unverified commit 9709147e authored by Xiaowei Ren, committed by GitHub

Add attention bias and qkv format to context parallelism (#726)



* make FusedAttn with CP support bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert Alibi cannot work with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* syntax fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix variable name
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix tensor shapes
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* a typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias indexing for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add attn bias tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change dbias update location
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP test model configs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change CP test sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP support qkv format of sbhd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure qkv are contiguous for CP in cuDNN fused attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change assert message
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent f85553ea
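Background for the diff below: with context parallelism (CP) of size cp_size, each sequence is split into 2*cp_size chunks along the sequence dimension, and rank r keeps the chunk pair (r, 2*cp_size-1-r) so that causal masking leaves every rank a similar amount of work. The new bias support shards the [b, h, sq, sk] bias the same way, which is why step i on rank r gathers bias chunks idx = (rank - i) % cp_size and 2*cp_size - idx - 1. A minimal sketch of the chunk pairing (the helper name is illustrative, not part of this commit):

    import torch

    def zigzag_shard(x: torch.Tensor, seq_dim: int, rank: int, cp_size: int) -> torch.Tensor:
        # Split the sequence into 2*cp_size chunks and keep the pair
        # (rank, 2*cp_size - 1 - rank), mirroring the index_select in the tests.
        chunks = x.chunk(2 * cp_size, dim=seq_dim)
        return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=seq_dim)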
@@ -6,7 +6,7 @@ import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from test_fused_attn_with_cp import model_configs
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
@@ -17,8 +17,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
@@ -40,8 +42,6 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
assert(rank in cp_comm_ranks)
cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
config = model_configs[model]
assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
# instantiate core attn module
@@ -69,18 +69,30 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
# make sure all GPU ranks have same inputs
for x in [q, k, v, dout]:
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(q, k, v)
out = core_attn(
q, k, v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
)
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
@@ -88,8 +100,16 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
out_ = core_attn(q_, k_, v_)
out_ = core_attn(
q_, k_, v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
)
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
......
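A hedged restating of the bias sharding in the test above (assuming the reference bias has shape [1, 1, sq, sk]): only the query dimension is sharded, since every rank still covers the full key sequence over the ring steps. Illustrative helper, not part of the commit; in the test, seq_idx comes from the surrounding CP setup:

    import torch

    def shard_bias_rows(bias: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        ws2 = 2 * world_size
        seq_idx = torch.tensor([rank, ws2 - 1 - rank], device=bias.device)
        # [1, 1, sq, sk] -> [1, 1, 2*ws, sq//(2*ws), sk]; pick this rank's two chunks
        b = bias.view(*bias.shape[:-2], ws2, bias.shape[-2] // ws2, bias.shape[-1])
        b = b.index_select(2, seq_idx)
        # back to [1, 1, sq//ws, sk]
        return b.view(*b.shape[:2], -1, b.shape[-1])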
@@ -11,12 +11,12 @@ from test_fused_attn import (
_cudnn_version,
)
model_configs = {
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
}
def get_bash_arguments(**kwargs):
@@ -30,7 +30,7 @@ def get_bash_arguments(**kwargs):
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
@@ -43,9 +43,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
check=True
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
}
@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
......
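Assuming the file lives at tests/pytorch/fused_attn/test_fused_attn_with_cp.py (the path may differ by release), the new bias cases can be run selectively, e.g.:

    pytest -q tests/pytorch/fused_attn/test_fused_attn_with_cp.py -k "fused_attention and cp_1_2"

Each case launches the distributed run_dpa_with_cp script via get_bash_arguments and subprocess.run, so multiple GPUs are required.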
@@ -490,9 +490,10 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim,
softmax_lse, softmax_lse_per_step):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2)
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step*softmax_lse_corrected_exp
out.add_(out_corrected)
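The correction above merges each ring step's partial attention output using its softmax log-sum-exp statistic. A self-contained sketch of the identity it relies on (single head, no dropout; purely illustrative, not part of the commit):

    import torch

    s1, s2 = torch.randn(4, 8), torch.randn(4, 8)    # scores against two KV chunks
    v1, v2 = torch.randn(8, 16), torch.randn(8, 16)
    lse1 = torch.logsumexp(s1, dim=-1)               # per-step statistics
    lse2 = torch.logsumexp(s2, dim=-1)
    lse = torch.logsumexp(torch.stack([lse1, lse2]), dim=0)  # merged statistic
    out1 = torch.softmax(s1, dim=-1) @ v1            # per-step partial outputs
    out2 = torch.softmax(s2, dim=-1) @ v2
    merged = torch.exp(lse1 - lse)[:, None] * out1 + torch.exp(lse2 - lse)[:, None] * out2
    full = torch.softmax(torch.cat([s1, s2], dim=-1), dim=-1) @ torch.cat([v1, v2], dim=0)
    assert torch.allclose(merged, full, atol=1e-5)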
@@ -516,22 +517,44 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
deterministic, use_fused_attention):
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
send_dst = cp_global_ranks[(rank + 1) % cp_size]
recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
recv_src = cp_global_ranks[(rank - 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
causal = (attn_mask_type == "causal")
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
if causal:
if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn]
q, k, v = [x.view(2, x.shape[0]//2, *x.shape[1:]) for x in [q, k, v]]
if attn_bias is not None:
assert (len(attn_bias.shape) == 4), (
"Only support bias shape of [b, h, sq, sk] for forward, "
"and [1, h, sq, sk] for backward!"
)
# [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
attn_bias_ = attn_bias.view( \
*attn_bias.shape[:-2], \
2, attn_bias.shape[-2]//2, \
2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
)
# [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
attn_bias = attn_bias.view( \
*attn_bias.shape[:-1], \
2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
)
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be a multiple of 8"
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
@@ -542,10 +565,12 @@ class AttnFuncWithCP(torch.autograd.Function):
# Flash Attn inputs
q_inputs = [None, None]
kv_inputs = [None, None]
attn_bias_inputs = [None, None]
# Flash Attn outputs
out_per_step = [None for _ in range(cp_size)]
softmax_lse_per_step = [None for _ in range(cp_size)]
rng_states = [None for _ in range(cp_size)]
attn_biases = [None for _ in range(cp_size)]
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
@@ -577,20 +602,37 @@ class AttnFuncWithCP(torch.autograd.Function):
if causal:
if i == 0:
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, -1, *k.shape[-3:])
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i%2] = torch.cat(
(attn_bias[..., idx, :], \
attn_bias[..., (2*cp_size-idx-1), :]),
dim=-1
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
qkv_layout=qkv_layout, attn_mask_type="causal",
attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
)
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
@@ -605,19 +647,31 @@ class AttnFuncWithCP(torch.autograd.Function):
)
elif i <= rank:
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous()
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
qkv_layout=qkv_layout, attn_mask_type="no_mask",
attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
)
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
@@ -636,20 +690,37 @@ class AttnFuncWithCP(torch.autograd.Function):
)
else:
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i%2] = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, -1, *k.shape[-3:])
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i%2] = torch.cat(
(attn_bias_[..., 1, :, idx, :], \
attn_bias_[..., 1, :, (2*cp_size-idx-1), :]),
dim=-1
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
fused_attn_fwd(
is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
qkv_layout=qkv_layout, attn_mask_type="no_mask",
attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
)
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
@@ -666,15 +737,24 @@ class AttnFuncWithCP(torch.autograd.Function):
)
else:
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i%2] = torch.cat(
(attn_bias[..., idx, :], attn_bias[..., (2*cp_size-idx-1), :]),
dim=-1
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q, kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
qkv_layout=qkv_layout, attn_mask_type="no_mask",
attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
)
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
@@ -719,23 +799,33 @@ class AttnFuncWithCP(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
seq_dim = qkv_format.index("s")
for i in range(cp_size):
# [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
out_ = out[:, 1, ...]
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]
if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_.shape),
out_,
flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i])
else:
flash_attn_fwd_out_correction(out[:, 1, ...],
out_,
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
kv = p2p_comm_buffers[-1]
if use_fused_attention:
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
elif qkv_format == "sbhd":
out = out.view(-1, *out.shape[-3:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
@@ -747,6 +837,10 @@ class AttnFuncWithCP(torch.autograd.Function):
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.qkv_format = qkv_format
ctx.attn_bias_type = attn_bias_type
ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
ctx.attn_biases = attn_biases
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
return out
@@ -757,10 +851,26 @@ class AttnFuncWithCP(torch.autograd.Function):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
if ctx.attn_biases[0] is not None:
# [b, np, sq, 2*cp, sk//(2*cp)]
attn_dbias = torch.zeros(
*ctx.attn_bias_shape,
dtype=ctx.attn_biases[0].dtype,
device=ctx.attn_biases[0].device
)
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
attn_dbias_ = attn_dbias.view(
*attn_dbias.shape[:-3], 2, attn_dbias.shape[-3]//2, *attn_dbias.shape[-2:]
)
else:
attn_dbias = None
if ctx.causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
@@ -814,6 +924,7 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.causal:
if i == (cp_size-1):
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
@@ -821,16 +932,28 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
q_, kv_[0], kv_[1], out_, dout_,
TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
qkv_layout=qkv_layout,
attn_mask_type="causal",
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
@@ -854,6 +977,7 @@ class AttnFuncWithCP(torch.autograd.Function):
)
elif i >= (cp_size-rank-1):
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
@@ -861,16 +985,28 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
kv_ = kv[:, 0, ...].contiguous()
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
q_, kv_[0], kv_[1], out_, dout_,
TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
qkv_layout=qkv_layout,
attn_mask_type="no_mask",
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
@@ -894,6 +1030,7 @@ class AttnFuncWithCP(torch.autograd.Function):
)
else:
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
@@ -901,16 +1038,28 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
dq_, dk_, dv_, _ = fused_attn_bwd(
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
out_ = out[1].contiguous()
dout_ = dout[1].contiguous()
aux_ctx_tensors = [softmax_lse_, ctx.rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]],
q_, kv_[0], kv_[1], out_, dout_,
TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
qkv_layout=qkv_layout,
attn_mask_type="no_mask",
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
@@ -934,16 +1083,20 @@ class AttnFuncWithCP(torch.autograd.Function):
)
else:
if ctx.use_fused_attention:
dq_, dk_, dv_, _ = fused_attn_bwd(
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
q, kv[0], kv[1], out, dout,
TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
qkv_layout=qkv_layout,
attn_mask_type="no_mask",
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
@@ -970,8 +1123,12 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal
dq_ = dq_.view(*dq.shape)
else:
if ctx.qkv_format == "bshd":
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
elif ctx.qkv_format == "sbhd":
# [b*sq//2, np, hn] -> [sq//2, b, np, hn]
dq_ = dq_.view(-1, *dq.shape[-3:])
if ctx.causal:
if i > (cp_size-rank-1):
@@ -980,18 +1137,44 @@ class AttnFuncWithCP(torch.autograd.Function):
if rank == (cp_size-1):
dq.copy_(dq_)
else:
if ctx.qkv_format == "bshd":
dq[:, 0, ...].copy_(dq_[:, 0, ...])
dq[:, 1, ...].add_(dq_[:, 1, ...])
elif ctx.qkv_format == "sbhd":
dq[0].copy_(dq_[0])
dq[1].add_(dq_[1])
elif i > 0:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].add_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].add_(dq_)
else:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].copy_(dq_)
else:
if i == 0:
dq.copy_(dq_)
else:
dq.add_(dq_)
if attn_dbias is not None:
idx = (rank+i+1)%cp_size
if i == (cp_size - 1) or not ctx.causal:
# [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
attn_dbias[..., (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])
elif i >= (cp_size-rank-1):
# [b, np, sq, sk//(2*cp)]
attn_dbias[..., idx, :].copy_(dbias_)
else:
# [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
attn_dbias_[..., 1, :, (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])
# wait until dKV is received
for req in send_recv_reqs:
req.wait()
@@ -1000,8 +1183,12 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
if ctx.qkv_format == "bshd":
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
elif ctx.qkv_format == "sbhd":
# [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
@@ -1010,15 +1197,25 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.causal:
if i == (cp_size-1):
if rank == 0:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
else:
dkv.add_(dkv_)
elif i >= (cp_size-rank-1):
if i == 0 and rank == (cp_size-1):
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
else:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_)
elif i > 0:
dkv.add_(dkv_)
else:
@@ -1030,26 +1227,44 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv.add_(dkv_)
if ctx.causal:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])
if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
None, None, None, None, None, None
None, None, None, None, None, None, attn_dbias, None, None
def attn_forward_func_with_cp(
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal",
deterministic=False, use_fused_attention=False
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False,
use_fused_attention=False
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
assert(qkv_format in ["bshd", "sbhd"]
), f"QKV format of {qkv_format} is not supported with context parallelism!"
assert(qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert (attn_mask_type in ["causal", "no_mask"]
), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
assert (attn_bias is None or use_fused_attention
), "Attention bias is only supported with FusedAttention!"
out = AttnFuncWithCP.apply(
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
deterministic, use_fused_attention
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
)
return out
@@ -1857,6 +2072,7 @@ class FlashAttention(torch.nn.Module):
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
attn_mask_type=attn_mask_type,
deterministic=self.deterministic
)
@@ -2821,10 +3037,10 @@ class FusedAttention(TransformerEngineBaseModule):
assert (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
), f"{fused_attention_backend} does not work with context parallelism!"
assert (core_attention_bias_type == "no_bias"), \
"Core attention bias has not been supported with context parallelism yet!"
if qkv_format == 'sbhd':
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
assert (
core_attention_bias_type not in ["alibi"]
), f"{core_attention_bias_type} is not supported with context parallelism!"
query_layer, key_layer, value_layer = [x.contiguous()
for x in (query_layer, key_layer, value_layer)]
with self.attention_dropout_ctx():
output = attn_forward_func_with_cp(
@@ -2835,11 +3051,12 @@ class FusedAttention(TransformerEngineBaseModule):
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
attn_bias_type=core_attention_bias_type,
attn_bias=core_attention_bias,
use_fused_attention=True,
)
if qkv_format == 'sbhd':
output = output.transpose(0,1).contiguous()
else:
with self.prepare_forward(query_layer,
is_first_microbatch,
......
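Finally, a hedged sketch of the ring pattern the forward and backward passes above assume: KV blocks travel rank -> rank+1 in forward, and dKV travels the opposite direction in backward. This is illustrative only and assumes the CP group spans all ranks so group and global ranks coincide; the real code maps group ranks through cp_global_ranks and overlaps these transfers with compute:

    import torch.distributed as dist

    def ring_step(send_buf, recv_buf, rank, cp_size, reverse=False):
        # Forward sends KV to (rank + 1) % cp_size; backward (reverse=True)
        # sends dKV to (rank - 1) % cp_size, as in AttnFuncWithCP above.
        step = -1 if reverse else 1
        ops = [dist.P2POp(dist.isend, send_buf, (rank + step) % cp_size),
               dist.P2POp(dist.irecv, recv_buf, (rank - step) % cp_size)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()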