Unverified commit f68df153, authored by Xiaowei Ren, committed by GitHub

[PyTorch] Add support for cuDNN FusedAttention + THD + CP (#885)



* add seq_offsets_qkvo for cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 90f3c9ad
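For readers new to the THD layout this diff keeps referring to, here is a minimal sketch of how a packed variable-length batch relates cumulative sequence lengths to the new seq_offsets_* arguments. The scaling shown (element offsets as cu_seqlens times heads times head_dim) is my reading of the change, not code from this commit.

import torch

# Three sequences packed back-to-back with no padding ("thd" layout).
seqlens = torch.tensor([3, 5, 2])  # token counts per sequence
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])  # [0, 3, 8, 10]
t, num_heads, head_dim = int(cu_seqlens[-1]), 12, 128
q = torch.randn(t, num_heads, head_dim)  # token dimension first, no batch dimension
# Hypothetical: per-sequence element offsets of the kind seq_offsets_q carries
# into the cuDNN fused-attention kernels.
seq_offsets_q = cu_seqlens * num_heads * head_dim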
@@ -22,6 +22,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         config = model_configs_fused_attn[model]
+        if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
+            return
     rank = int(os.getenv('RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
@@ -45,6 +47,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
+    if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
+        if 'causal' in config.attn_mask_type:
+            config.attn_mask_type = 'padding_causal'
+        else:
+            config.attn_mask_type = 'padding'
     # instantiate core attn module
     core_attn = DotProductAttention(config.num_heads,
                                     config.head_dim,
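The conversion above, restated as a one-line rule (my paraphrase of the test logic, not code from the diff): THD batches are packed variable-length sequences, so the fused kernel is driven with a padding-aware mask even when the single-GPU reference ran with 'causal' or 'no_mask'.

def to_thd_mask(mask_type):
    # THD needs the padding-aware variant of whatever mask the reference used.
    return 'padding_causal' if 'causal' in mask_type else 'padding'

assert to_thd_mask('causal') == 'padding_causal'
assert to_thd_mask('no_mask') == 'padding'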
@@ -112,9 +120,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     out.backward(dout)
     # run core_attn with CP
+    q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
+    bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
-        q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
-        bias_ = rest[0] if len(rest) else None
         seq_dim = qkv_format.index('s')
         q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
             for x in [q_, k_, v_, dout_]]
@@ -122,14 +130,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
         q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
     elif qkv_format == "thd":
-        q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
         seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
         seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
         cu_seqlens_q = cu_seqlens_q // world_size
         cu_seqlens_kv = cu_seqlens_kv // world_size
-        bias_ = None
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
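A hedged reference for what tex.thd_get_partitioned_indices appears to compute here (my reconstruction from how the test uses it, not the kernel source): the load-balanced causal CP schedule cuts every sequence into 2*world_size chunks, and rank r keeps chunks r and 2*world_size-1-r, so each rank gets an equal mix of early and late tokens under a causal mask.

import torch

def thd_partition_indices_ref(cu_seqlens, world_size, rank):
    # Token indices this rank would keep out of the packed THD batch.
    idx = []
    for b in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        chunk = (end - start) // (2 * world_size)
        for c in (rank, 2 * world_size - 1 - rank):
            idx.append(torch.arange(start + c * chunk, start + (c + 1) * chunk))
    return torch.cat(idx)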
@@ -158,7 +164,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     # compare results with and without CP
     tols = dict(atol=5e-3, rtol=5e-3)
     if dtype == 'bf16':
-        tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        if config.num_heads == config.num_gqa_groups:
+            tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        else:
+            tols = dict(atol=3.5e-2, rtol=3.5e-2)
     if qkv_format == "bshd" or qkv_format == "sbhd":
         dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
......
@@ -14,10 +14,10 @@ from transformer_engine.pytorch.utils import get_device_compute_capability
 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
-    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
-    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
-    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
 }
 def get_bash_arguments(**kwargs):
@@ -47,21 +47,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
 model_configs_fused_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
-    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
-    "cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
-    "cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
-    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
-    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
-    "cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
-    "cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
+    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
+    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+    "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+    "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
 }
 @pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
+@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
 def test_cp_with_fused_attention(dtype, model, qkv_format):
     subprocess.run(
         get_bash_arguments(
......
@@ -520,8 +520,9 @@ class AttnFuncWithCP(torch.autograd.Function):
     @staticmethod
     def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
-                attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention):
+                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
+                cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type,
+                attn_bias_type, attn_bias, deterministic, use_fused_attention):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -531,7 +532,8 @@ class AttnFuncWithCP(torch.autograd.Function):
         recv_src = cp_global_ranks[(rank - 1) % cp_size]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
-        causal = (attn_mask_type == "causal")
+        causal = ("causal" in attn_mask_type)
+        padding = ("padding" in attn_mask_type)
         qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
@@ -617,6 +619,8 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                     kv_inputs[i%2] = kv_inputs[i%2].view(
                                         2, -1, *k.shape[-3:])
+                                elif qkv_format == "thd":
+                                    q_inputs[i%2] = q
                                 if attn_bias is not None:
                                     idx = (rank - i) % cp_size
                                     attn_bias_inputs[i%2] = torch.cat(
@@ -631,8 +635,10 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     kv_inputs[i%2][1], TE_DType[q.dtype],
                                     tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                     attn_scale=softmax_scale, dropout=dropout_p,
-                                    qkv_layout=qkv_layout, attn_mask_type="causal",
+                                    qkv_layout=qkv_layout, attn_mask_type=attn_mask_type,
                                     attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
+                                    seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k,
+                                    seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o,
                                 )
                                 if len(rest) > 0:
                                     attn_biases[i] = rest[0]
@@ -660,6 +666,11 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     q_inputs[i%2] = q.view(-1, *q.shape[-3:])
                                     # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                                     kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous()
+                                elif qkv_format == "thd":
+                                    q_inputs[i%2] = q
+                                    # [2, t, np, hn] -> [2, t/2, np, hn]
+                                    kv_inputs[i%2] = tex.thd_read_half_tensor(
+                                        kv_inputs[i%2], cu_seqlens_k, 0)
                                 if attn_bias is not None:
                                     idx = (rank - i) % cp_size
                                     attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous()
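A reference sketch of what tex.thd_read_half_tensor(x, cu_seqlens, half_idx) is expected to do, based on the shape comments in this diff rather than the CUDA kernel: gather the first (half_idx=0) or second (half_idx=1) half of every packed sequence. Indexing the dimension third from last covers both q of shape [t, np, hn] and kv of shape [2, t, np, hn].

import torch

def thd_read_half_tensor_ref(x, cu_seqlens, half_idx):
    # Gather one half of every sequence out of a packed THD tensor.
    idx = []
    for b in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        mid = start + (end - start) // 2
        idx.append(torch.arange(start, mid) if half_idx == 0 else torch.arange(mid, end))
    return x.index_select(x.dim() - 3, torch.cat(idx))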
@@ -669,9 +680,18 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
                                     kv_inputs[i%2][1], TE_DType[q.dtype],
                                     tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
-                                    attn_scale=softmax_scale, dropout=dropout_p,
-                                    qkv_layout=qkv_layout, attn_mask_type="no_mask",
-                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
+                                    attn_scale=softmax_scale,
+                                    dropout=dropout_p,
+                                    qkv_layout=qkv_layout,
+                                    attn_mask_type="padding" if padding else "no_mask",
+                                    attn_bias_type=attn_bias_type,
+                                    attn_bias=attn_bias_inputs[i%2],
+                                    seq_offsets_q=seq_offsets_q,
+                                    seq_offsets_k=None if seq_offsets_k is None \
+                                        else seq_offsets_k//2,
+                                    seq_offsets_v=None if seq_offsets_v is None \
+                                        else seq_offsets_v//2,
+                                    seq_offsets_o=seq_offsets_o,
                                 )
                                 if len(rest) > 0:
                                     attn_biases[i] = rest[0]
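Why seq_offsets_k//2 is the right adjustment when only the first half of each KV sequence participates: under my assumption that these offsets are cu_seqlens scaled by a constant heads*head_dim factor, halving every sequence length halves every cumulative offset, so integer-dividing the precomputed offsets by two matches recomputing them from cu_seqlens_k//2. The division is exact because the CP split requires sequence lengths divisible by 2*cp_size. A worked check under that assumption:

import torch

num_gqa_groups, head_dim = 1, 128
cu_seqlens_k = torch.tensor([0, 4096, 8192])
seq_offsets_k = cu_seqlens_k * num_gqa_groups * head_dim
assert torch.equal(seq_offsets_k // 2, (cu_seqlens_k // 2) * num_gqa_groups * head_dim)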
@@ -710,6 +730,9 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                     kv_inputs[i%2] = kv_inputs[i%2].view(
                                         2, -1, *k.shape[-3:])
+                                elif qkv_format == "thd":
+                                    # [t, np, hn] -> [t/2, np, hn]
+                                    q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                                 if attn_bias is not None:
                                     idx = (rank - i) % cp_size
                                     attn_bias_inputs[i%2] = torch.cat(
@@ -723,9 +746,18 @@ class AttnFuncWithCP(torch.autograd.Function):
                                     cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                     kv_inputs[i%2][1], TE_DType[q.dtype],
                                     tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
-                                    attn_scale=softmax_scale, dropout=dropout_p,
-                                    qkv_layout=qkv_layout, attn_mask_type="no_mask",
-                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
+                                    attn_scale=softmax_scale,
+                                    dropout=dropout_p,
+                                    qkv_layout=qkv_layout,
+                                    attn_mask_type="padding" if padding else "no_mask",
+                                    attn_bias_type=attn_bias_type,
+                                    attn_bias=attn_bias_inputs[i%2],
+                                    seq_offsets_q=None if seq_offsets_q is None \
+                                        else seq_offsets_q//2,
+                                    seq_offsets_k=seq_offsets_k,
+                                    seq_offsets_v=seq_offsets_v,
+                                    seq_offsets_o=None if seq_offsets_o is None \
+                                        else seq_offsets_o//2,
                                 )
                                 if len(rest) > 0:
                                     attn_biases[i] = rest[0]
@@ -763,8 +795,10 @@ class AttnFuncWithCP(torch.autograd.Function):
                                 kv_inputs[i%2][1], TE_DType[q.dtype],
                                 tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                 attn_scale=softmax_scale, dropout=dropout_p,
-                                qkv_layout=qkv_layout, attn_mask_type="no_mask",
+                                qkv_layout=qkv_layout, attn_mask_type=attn_mask_type,
                                 attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
+                                seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k,
+                                seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o,
                             )
                             if len(rest) > 0:
                                 attn_biases[i] = rest[0]
@@ -870,16 +904,19 @@ class AttnFuncWithCP(torch.autograd.Function):
         else:
             out = out.view(-1, *out.shape[-2:])
-        ctx.save_for_backward(q, kv, out, softmax_lse,
-                              cu_seqlens_q, cu_seqlens_k, *rng_states, *attn_biases)
+        ctx.save_for_backward(
+            q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
+            *rng_states, *attn_biases
+        )
         ctx.cp_group = cp_group
         ctx.cp_global_ranks = cp_global_ranks
         ctx.dropout_p = dropout_p
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
-        ctx.causal = causal
         ctx.qkv_format = qkv_format
+        ctx.attn_mask_type = attn_mask_type
         ctx.attn_bias_type = attn_bias_type
         ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
         ctx.deterministic = deterministic
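The index shift in backward (next hunk) follows directly from this save order: six tensors now precede the four seq_offsets_*, so the rng_states/attn_biases slices move from base 6 to base 10. A quick sanity list, illustrative only:

saved = ["q", "kv", "out", "softmax_lse", "cu_seqlens_q", "cu_seqlens_k",
         "seq_offsets_q", "seq_offsets_k", "seq_offsets_v", "seq_offsets_o"]
assert len(saved) == 10  # rng_states now start at ctx.saved_tensors[10]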
@@ -889,15 +926,18 @@ class AttnFuncWithCP(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout):
         (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
+        (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10]
         cp_size = get_distributed_world_size(ctx.cp_group)
-        rng_states = ctx.saved_tensors[6:6+cp_size]
-        attn_biases = ctx.saved_tensors[6+cp_size:6+cp_size*2]
+        rng_states = ctx.saved_tensors[10:10+cp_size]
+        attn_biases = ctx.saved_tensors[10+cp_size:10+cp_size*2]
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        causal = ("causal" in ctx.attn_mask_type)
+        padding = ("padding" in ctx.attn_mask_type)
         qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
         if attn_biases[0] is not None:
@@ -914,7 +954,7 @@ class AttnFuncWithCP(torch.autograd.Function):
         else:
             attn_dbias = None
-        if ctx.causal:
+        if causal:
             if ctx.qkv_format == "thd":
                 softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
             else:
@@ -969,7 +1009,7 @@ class AttnFuncWithCP(torch.autograd.Function):
             kv = p2p_comm_buffers[i%2][0]
             # In reversed order of fwd
-            if ctx.causal:
+            if causal:
                 if i == (cp_size-1):
                     if ctx.use_fused_attention:
                         if ctx.qkv_format == "bshd":
@@ -988,6 +1028,8 @@ class AttnFuncWithCP(torch.autograd.Function):
                             # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                             out_ = out.view(-1, *out.shape[-3:])
                             dout_ = dout.view(-1, *dout.shape[-3:])
+                        elif ctx.qkv_format == "thd":
+                            q_, kv_, out_, dout_ = q, kv, out, dout
                         aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
                         if attn_dbias is not None:
                             aux_ctx_tensors += [attn_biases[cp_size-i-1]]
@@ -997,10 +1039,11 @@ class AttnFuncWithCP(torch.autograd.Function):
                             q_, kv_[0], kv_[1], out_, dout_,
                             TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
                             tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
+                            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                             attn_scale=ctx.softmax_scale,
                             dropout=ctx.dropout_p,
                             qkv_layout=qkv_layout,
-                            attn_mask_type="causal",
+                            attn_mask_type=ctx.attn_mask_type,
                             attn_bias_type=ctx.attn_bias_type,
                         )
                     else:
@@ -1041,6 +1084,10 @@ class AttnFuncWithCP(torch.autograd.Function):
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
+                        elif ctx.qkv_format == "thd":
+                            q_, out_, dout_ = q, out, dout
+                            # [2, t, np, hn] -> [2, t/2, np, hn]
+                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                         aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
                         if attn_dbias is not None:
                             aux_ctx_tensors += [attn_biases[cp_size-i-1]]
@@ -1050,10 +1097,12 @@ class AttnFuncWithCP(torch.autograd.Function):
                             q_, kv_[0], kv_[1], out_, dout_,
                             TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
                             tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
+                            seq_offsets_q, None if seq_offsets_k is None else seq_offsets_k//2,
+                            None if seq_offsets_v is None else seq_offsets_v//2, seq_offsets_o,
                             attn_scale=ctx.softmax_scale,
                             dropout=ctx.dropout_p,
                             qkv_layout=qkv_layout,
-                            attn_mask_type="no_mask",
+                            attn_mask_type="padding" if padding else "no_mask",
                             attn_bias_type=ctx.attn_bias_type,
                         )
                     else:
@@ -1098,6 +1147,12 @@ class AttnFuncWithCP(torch.autograd.Function):
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
+                        elif ctx.qkv_format == "thd":
+                            # [t, np, hn] -> [t/2, np, hn]
+                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
+                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
+                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
+                            kv_ = kv
                         aux_ctx_tensors = [softmax_lse_, rng_states[cp_size-i-1]]
                         if attn_dbias is not None:
                             aux_ctx_tensors += [attn_biases[cp_size-i-1]]
@@ -1107,10 +1162,12 @@ class AttnFuncWithCP(torch.autograd.Function):
                             q_, kv_[0], kv_[1], out_, dout_,
                             TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
                             tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
+                            None if seq_offsets_q is None else seq_offsets_q//2, seq_offsets_k,
+                            seq_offsets_v, None if seq_offsets_o is None else seq_offsets_o//2,
                             attn_scale=ctx.softmax_scale,
                             dropout=ctx.dropout_p,
                             qkv_layout=qkv_layout,
-                            attn_mask_type="no_mask",
+                            attn_mask_type="padding" if padding else "no_mask",
                             attn_bias_type=ctx.attn_bias_type,
                         )
                     else:
@@ -1152,10 +1209,11 @@ class AttnFuncWithCP(torch.autograd.Function):
                         q, kv[0], kv[1], out, dout,
                         TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
                         tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
+                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                         attn_scale=ctx.softmax_scale,
                         dropout=ctx.dropout_p,
                         qkv_layout=qkv_layout,
-                        attn_mask_type="no_mask",
+                        attn_mask_type=ctx.attn_mask_type,
                         attn_bias_type=ctx.attn_bias_type,
                     )
                 else:
@@ -1178,7 +1236,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                     **fa_optional_backward_kwargs
                 )
-            if i >= (cp_size-rank-1) or not ctx.causal:
+            if i >= (cp_size-rank-1) or not causal:
                 # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                 # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                 dq_ = dq_.view(*dq.shape)
# [b*sq//2, np, hn] -> [sq//2, b, np, hn]
dq_ = dq_.view(-1, *dq.shape[-3:])
if ctx.causal:
if causal:
if i > (cp_size-rank-1):
dq.add_(dq_)
elif i == (cp_size-rank-1):
@@ -1227,7 +1285,7 @@ class AttnFuncWithCP(torch.autograd.Function):
             if attn_dbias is not None:
                 idx = (rank+i+1)%cp_size
-                if i == (cp_size - 1) or not ctx.causal:
+                if i == (cp_size - 1) or not causal:
                     # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
                     dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
                     attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
@@ -1248,7 +1306,7 @@ class AttnFuncWithCP(torch.autograd.Function):
             dkv = p2p_comm_buffers[(i+1)%2][1]
             if ctx.use_fused_attention:
                 dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
-            if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
+            if causal and i >= (cp_size-rank-1) and i != (cp_size-1):
                 if ctx.qkv_format == "bshd":
                     # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                     dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
@@ -1260,7 +1318,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                 # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                 dkv_ = dkv_.view(*dkv.shape)
-            if ctx.causal:
+            if causal:
                 if i == (cp_size-1):
                     if rank == 0:
                         if ctx.qkv_format == "bshd":
@@ -1298,7 +1356,7 @@ class AttnFuncWithCP(torch.autograd.Function):
             else:
                 dkv.add_(dkv_)
-        if ctx.causal:
+        if causal:
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                dq = dq.view(q.shape[0], -1, *q.shape[-2:])
@@ -1314,13 +1372,14 @@ class AttnFuncWithCP(torch.autograd.Function):
             # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
             attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
-        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
-            None, None, None, None, None, None, attn_dbias, None, None
+        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, \
+            None, None, None, None, None, None, None, None, attn_dbias, None, None
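The longer tuple is pure autograd bookkeeping: backward must return one slot per forward input, and forward gained four seq_offsets_* inputs, so four extra None placeholders appear while attn_dbias stays aligned with the attn_bias argument. A quick count, illustrative only:

forward_args = 19 + 4       # original forward inputs plus seq_offsets_q/k/v/o
new_return_slots = 12 + 11  # entries on the two lines of the return above
assert forward_args == new_return_slots == 23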
 def attn_forward_func_with_cp(
     is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-    dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
+    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
+    cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
     attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False,
     use_fused_attention=False
 ) -> torch.Tensor:
@@ -1329,16 +1388,19 @@ def attn_forward_func_with_cp(
     ), f"QKV format of {qkv_format} is not supported with context parallelism!"
     assert(qkv_format != "sbhd" or use_fused_attention
     ), "FlashAttention does not support sbhd format!"
-    assert(not(qkv_format == "thd" and use_fused_attention)
-    ), "FusedAttention does not support thd format!"
-    assert (attn_mask_type in ["causal", "no_mask"]
-    ), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
-    assert (attn_bias is None or use_fused_attention
-    ), "Attention bias is only supported with FusedAttention!"
+    assert (qkv_format != 'thd' or \
+            not use_fused_attention or \
+            attn_mask_type in ["padding", "padding_causal"]
+    ), f"""Context parallelism is not supported for {attn_mask_type} mask type and """ \
+       f"""{qkv_format} format with {"FusedAttention" if use_fused_attention else "FlashAttention"}!"""
+    assert (attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type)
+    ), """Attention bias is only supported with FusedAttention and "causal" """ \
+       """or "no_mask" mask types!"""
     out = AttnFuncWithCP.apply(
-        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
-        attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
+        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
+        cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type,
+        attn_bias_type, attn_bias, deterministic, use_fused_attention
     )
     return out
@@ -2147,6 +2209,7 @@ class FlashAttention(torch.nn.Module):
                 output = attn_forward_func_with_cp(
                     self.training, query_layer, key_layer, value_layer,
                     cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
+                    None, None, None, None,
                     self.attention_dropout if self.training else 0.0,
                     cp_group, cp_global_ranks, cp_stream,
                     softmax_scale=1.0/self.norm_factor,
@@ -3131,7 +3194,6 @@ class FusedAttention(TransformerEngineBaseModule):
                 key_layer.device,
             )
         if qkv_format == 'thd':
-            assert not context_parallel, "thd format not supported with context parallelism!"
             assert (max_seqlen_q is not None
                     and max_seqlen_kv is not None
                     and cu_seqlens_q is not None
@@ -3140,8 +3202,10 @@ class FusedAttention(TransformerEngineBaseModule):
             if (seq_offsets_q is None
                     or seq_offsets_k is None
                     or seq_offsets_v is None
-                    or seq_offsets_o is None):
+                    or seq_offsets_o is None
+                    or context_parallel):
                 qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
+                qkv_group = 'hd_hd_hd' if context_parallel else qkv_group
                 num_heads = query_layer.shape[-2]
                 num_gqa_groups = key_layer.shape[-2]
                 head_dim = query_layer.shape[-1]
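A sketch of the offsets this branch presumably derives for the 'hd_hd_hd' group that CP forces (hypothetical helper; the real computation lives in the lines omitted below this hunk): with separate Q/K/V tensors, each sequence's data starts at its cumulative token count times heads times head_dim, and the output follows the query layout.

def seq_offsets_hd_hd_hd(cu_seqlens_q, cu_seqlens_kv, num_heads, num_gqa_groups, head_dim):
    # Per-sequence element offsets into separate, packed Q/K/V/O buffers.
    seq_offsets_q = cu_seqlens_q * num_heads * head_dim
    seq_offsets_k = cu_seqlens_kv * num_gqa_groups * head_dim
    seq_offsets_v = seq_offsets_k
    seq_offsets_o = seq_offsets_q  # output is written in the query layout
    return seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o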
@@ -3181,6 +3245,7 @@ class FusedAttention(TransformerEngineBaseModule):
                     query_layer, key_layer, value_layer,
                     cu_seqlens_q, cu_seqlens_kv,
                     max_seqlen_q, max_seqlen_kv,
+                    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                     self.attention_dropout if self.training else 0.0,
                     cp_group, cp_global_ranks, cp_stream,
                     softmax_scale=1.0/self.norm_factor,
......