Unverified Commit 479dbb73 authored by Xiaowei Ren, committed by GitHub

add flash implementation with context parallelism (#362)



* add flash implementation with context parallelism
Signed-off-by: xren <xren@nvidia.com>

* next more comments
Signed-off-by: xren <xren@nvidia.com>

* code comment fix
Signed-off-by: xren <xren@nvidia.com>

* comment fix
Signed-off-by: xren <xren@nvidia.com>

* add missing space
Signed-off-by: xren <xren@nvidia.com>

* fix docstrings
Signed-off-by: xren <xren@nvidia.com>

* try to add fa v2 api
Signed-off-by: xren <xren@nvidia.com>

* fix a comment
Signed-off-by: xren <xren@nvidia.com>

* fix padded kv return
Signed-off-by: xren <xren@nvidia.com>

* add docstrings of context parallelism
Signed-off-by: xren <xren@nvidia.com>

* minor fix
Signed-off-by: xren <xren@nvidia.com>

* minor docstring fix
Signed-off-by: xren <xren@nvidia.com>

* fix positional arguments
Signed-off-by: xren <xren@nvidia.com>

* make docstring line shorter
Signed-off-by: xren <xren@nvidia.com>

* add fa v2 backward api for flash_attn_with_cp
Signed-off-by: xren <xren@nvidia.com>

* remove redundant code
Signed-off-by: xren <xren@nvidia.com>

* make sure hidden size per attn head is multiple of 8 for FA2
Signed-off-by: xren <xren@nvidia.com>

* remove an unnecessary assert check for FA2
Signed-off-by: xren <xren@nvidia.com>

* indention fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* Update FA version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: xren <xren@nvidia.com>
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b95c1818
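For orientation before the diff: the scheme this commit implements can be pictured with a small single-process sketch (not part of the change). Each context-parallel rank keeps its query shard while the key/value shards rotate around the ring one step at a time, and the per-step partial outputs are merged with a log-sum-exp correction. The function name and shapes below are made up for illustration, and the sketch deliberately ignores causal masking, dropout, the two-chunk-per-rank layout, and the actual P2P/stream overlap that the real implementation adds.

import torch

def ring_attention_reference(q, k, v, cp_size):
    # q, k, v: [s, h]; split the sequence into cp_size contiguous shards
    q_shards = q.chunk(cp_size, dim=0)
    k_shards = list(k.chunk(cp_size, dim=0))
    v_shards = list(v.chunk(cp_size, dim=0))
    outputs = []
    for rank in range(cp_size):
        q_i = q_shards[rank]
        out = torch.zeros_like(q_i)
        lse = torch.full((q_i.shape[0],), float("-inf"), dtype=torch.double)
        for step in range(cp_size):
            # at step `step`, this rank holds the KV shard that started on rank (rank - step)
            src = (rank - step) % cp_size
            scores = q_i @ k_shards[src].T / q_i.shape[-1] ** 0.5
            lse_step = torch.logsumexp(scores.double(), dim=-1)
            out_step = torch.softmax(scores, dim=-1) @ v_shards[src]
            # merge the per-step softmax stats, then rescale the partial outputs
            new_lse = torch.logaddexp(lse, lse_step)
            out = out * (lse - new_lse).exp().unsqueeze(-1).float() \
                + out_step * (lse_step - new_lse).exp().unsqueeze(-1).float()
            lse = new_lse
        outputs.append(out)
    return torch.cat(outputs, dim=0)

torch.manual_seed(0)
q, k, v = (torch.randn(8, 16) for _ in range(3))
full = torch.softmax(q @ k.T / 16 ** 0.5, dim=-1) @ v
assert torch.allclose(ring_attention_reference(q, k, v, cp_size=4), full, atol=1e-5)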
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    # Framework-specific requirements
    if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.1"])
        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.2"])
        add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
    if "jax" in frameworks():
        if not found_pybind11():
......
@@ -42,6 +42,7 @@ from transformer_engine.pytorch.constants import (
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
    get_distributed_rank,
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
@@ -53,13 +54,436 @@ _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
if _flash_attn_2_available:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
else:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward

__all__ = ["DotProductAttention", "MultiheadAttention"]
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
recv_tensor, recv_src,
cp_group, batch_p2p_comm):
"""Point-to-point communications of KV and dKV in Flash Attention with context parallelism"""
send_recv_ops = []
if batch_p2p_comm:
if rank % 2 == 0:
send_op = torch.distributed.P2POp(torch.distributed.isend,
send_tensor,
send_dst,
cp_group)
recv_op = torch.distributed.P2POp(torch.distributed.irecv,
recv_tensor,
recv_src,
cp_group)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.P2POp(torch.distributed.irecv,
recv_tensor,
recv_src,
cp_group)
send_op = torch.distributed.P2POp(torch.distributed.isend,
send_tensor,
send_dst,
cp_group)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
else:
if rank % 2 == 0:
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = send_recv_ops
return send_recv_reqs
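# Note on the ordering above: when batched P2P is not used, even ranks post the send first
# and odd ranks post the receive first. Alternating the order by rank parity is the usual
# way to pair up the isend/irecv calls around the ring and avoid every rank blocking on a
# send at once; torch.distributed.batch_isend_irecv handles this pairing internally.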
@torch.jit.script
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
"""Merge partial outputs of each step in Flash Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step*softmax_lse_corrected_exp
out.add_(out_corrected)
@torch.jit.script
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
"""Merge softmax stats of each step in Flash Attention with context parallelism"""
softmax_lse.exp_()
softmax_lse.add_(softmax_lse_per_step.to(torch.double).exp())
softmax_lse.log_()
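# Numerics of the two helpers above: softmax_lse holds log(sum_j exp(score_j)) over the keys
# seen so far, so flash_attn_fwd_softmax_lse_correction computes
#   lse_new = log(exp(lse) + exp(lse_step))
# in double precision, and flash_attn_fwd_out_correction rescales each step's partial output
# by exp(lse_step - lse_total) before accumulating it. This is the standard streaming-softmax
# merge, applied across context-parallel steps instead of across tiles inside one kernel.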
class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
"""
Flash Attention implementation with context parallelism.
Split flash attention compute into multiple steps, and overlap current-step
compute with next-step communication.
"""
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic):
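        # Scheduling overview: this rank keeps its Q shard for all cp_size steps while the KV
        # shard travels around the ring via p2p_comm_buffers. Step i runs on
        # flash_attn_streams[i % 2], so the flash-attention kernel of one step can overlap
        # with the P2P exchange and the output/softmax-lse correction of the neighboring step.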
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
send_dst = cp_global_ranks[(rank + 1) % cp_size]
recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
if _flash_attn_2_available:
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
# Flash Attn inputs
q_inputs = [None, None]
kv_inputs = [None, None]
# Flash Attn outputs
out_per_step = [None for _ in range(cp_size)]
softmax_lse_per_step = [None for _ in range(cp_size)]
rng_states = [None for _ in range(cp_size)]
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)]
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
for i in range(cp_size+1):
if i < cp_size:
with torch.cuda.stream(flash_attn_streams[i%2]):
# wait until KV is received
for req in send_recv_reqs[(i+1)%2]:
req.wait()
if i < (cp_size-1):
p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
p2p_comm_buffers[i],
send_dst,
p2p_comm_buffers[i+1],
recv_src,
cp_group,
batch_p2p_comm)
kv_inputs[i%2] = p2p_comm_buffers[i]
if causal:
if i == 0:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
causal=True, return_softmax=False,
)
elif i <= rank:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q, cu_seqlens_k//2,
max_seqlen_q, max_seqlen_k//2, dropout_p, softmax_scale,
causal=False, return_softmax=False,
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q//2, cu_seqlens_k,
max_seqlen_q//2, max_seqlen_k, dropout_p, softmax_scale,
causal=False, return_softmax=False,
)
else:
assert False, "Not implemented yet!"
if i > 0:
                # wait until fwd results correction of last step is done
if i > 1:
flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)
with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
if causal:
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
)
elif (i-1) <= rank:
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
else:
assert False, "Not implemented yet!"
if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
# [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
if i <= rank:
flash_attn_fwd_out_correction(out.view(*out_.shape),
out_,
softmax_lse,
softmax_lse_per_step[i])
else:
flash_attn_fwd_out_correction(out[:, 1, ...],
out_,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
kv = p2p_comm_buffers[-1]
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out
@staticmethod
def backward(ctx, dout):
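        # The backward pass walks the same ring in the opposite direction (send_dst is the
        # previous rank), re-runs flash-attention backward for each step with the saved
        # rng_states, and circulates the accumulated dKV through p2p_comm_buffers so that
        # every rank ends up with the gradient for its own KV shard.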
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
out = out.view(*q.shape)
dout = dout.view(*q.shape)
# Flash Attn outputs
dq = torch.empty_like(q)
p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_backward_kwargs["num_splits"] = 1 if ctx.deterministic else 0
for i in range(cp_size):
# wait until KV is received
for req in send_recv_reqs:
req.wait()
send_tensor = p2p_comm_buffers[i%2]
recv_tensor = p2p_comm_buffers[(i+1)%2]
if i == 0:
send_tensor = send_tensor[0]
recv_tensor = recv_tensor[0]
if i == (cp_size-1):
send_tensor = send_tensor[1]
recv_tensor = recv_tensor[1]
send_recv_reqs = flash_attn_p2p_communicate(rank,
send_tensor,
send_dst,
recv_tensor,
recv_src,
ctx.cp_group,
batch_p2p_comm)
kv = p2p_comm_buffers[i%2][0]
# In reversed order of fwd
if ctx.causal:
if i == (cp_size-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
elif i >= (cp_size-rank-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse_[..., 1, :],
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
if i >= (cp_size-rank-1):
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
dq_ = dq_.view(*dq.shape)
else:
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
if i > (cp_size-rank-1):
dq.add_(dq_)
elif i == (cp_size-rank-1):
if rank == (cp_size-1):
dq.copy_(dq_)
else:
dq[:, 0, ...].copy_(dq_[:, 0, ...])
dq[:, 1, ...].add_(dq_[:, 1, ...])
elif i > 0:
dq[:, 1, ...].add_(dq_)
else:
dq[:, 1, ...].copy_(dq_)
# wait until dKV is received
for req in send_recv_reqs:
req.wait()
dkv = p2p_comm_buffers[(i+1)%2][1]
if i >= (cp_size-rank-1) and i != (cp_size-1):
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape)
if i == (cp_size-1):
if rank == 0:
dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
else:
dkv.add_(dkv_)
elif i >= (cp_size-rank-1):
if i == 0 and rank == (cp_size-1):
dkv[:, :, 0, ...].copy_(dkv_)
else:
dkv[:, :, 0, ...].add_(dkv_)
elif i > 0:
dkv.add_(dkv_)
else:
dkv.copy_(dkv_)
else:
assert False, "Not implemented yet!"
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
return dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, None, None, None
def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=None, causal=False,
deterministic=False):
"""Flash Attention implementation with context parallelism"""
out = FlashAttnUnpaddedFuncWithCP.apply(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic
)
return out
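# Expectations visible from FlashAttnUnpaddedFuncWithCP: q/k/v arrive as the per-rank
# [b, s, np, hn] shard with an even per-rank sequence length (each shard is split into two
# halves for the causal schedule), only causal attention is implemented so far, and with
# flash-attn 2.x the hidden size per attention head must be a multiple of 8.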
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
@@ -482,6 +906,9 @@ class FlashAttention(torch.nn.Module):
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attn_mask_type: str = "causal",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop"""
@@ -508,12 +935,6 @@ class FlashAttention(torch.nn.Module):

        batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]

-        # [b, sq, np, hn]
-        query_layer, key_layer, value_layer = [
-            x.view(x.shape[0] * x.shape[1], *x.shape[2:])
-            for x in [query_layer, key_layer, value_layer]
-        ]

        max_seqlen = seqlen
        cu_seqlens = torch.arange(
            0,
@@ -522,16 +943,36 @@ class FlashAttention(torch.nn.Module):
            dtype=torch.int32,
            device=query_layer.device)

-        with self.attention_dropout_ctx():
-            fa_optional_forward_kwargs = {}
-            if not _flash_attn_2_available:
-                fa_optional_forward_kwargs["deterministic"] = self.deterministic
-            output = flash_attn_forward_func(
-                query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
-                self.attention_dropout if self.training else 0.0,
-                softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
-                **fa_optional_forward_kwargs
-            )
        if cp_group is None or get_distributed_world_size(cp_group) == 1:
            # [b, sq, np, hn]
            query_layer, key_layer, value_layer = [
                x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                for x in [query_layer, key_layer, value_layer]
            ]

            with self.attention_dropout_ctx():
                fa_optional_forward_kwargs = {}
                if not _flash_attn_2_available:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=1.0/self.norm_factor,
                    causal=attn_mask_type=="causal",
                    **fa_optional_forward_kwargs
                )
        else:
            with self.attention_dropout_ctx():
                output = flash_attn_forward_func_with_cp(
                    query_layer, key_layer, value_layer,
                    cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    causal=attn_mask_type=="causal",
                    deterministic=self.deterministic
                )

        # [(b sq), np, hn] -> [sq, b, (np hn)]
        return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
@@ -916,6 +1357,15 @@ class DotProductAttention(torch.nn.Module):
               tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
               tensor parallel process group.
cp_group : ProcessGroup, default = `None`
context parallel process group.
cp_global_ranks : list of global rank IDs, default = `None`
global rank IDs of GPUs that are in cp_group.
cp_stream : CUDA stream, default = `None`
context parallelism splits flash attention into multiple steps for
compute and communication overlapping. To address the wave quantization
issue of each split step, we add an additional CUDA stream so that we
can overlap two flash attention kernels.
""" """
def __init__( def __init__(
...@@ -931,6 +1381,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -931,6 +1381,9 @@ class DotProductAttention(torch.nn.Module):
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
attention_type: str = "self", attention_type: str = "self",
cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
    ) -> None:
        super().__init__()
@@ -946,6 +1399,9 @@ class DotProductAttention(torch.nn.Module):
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
        self.hidden_size_per_attention_head = kv_channels

        self.num_gqa_groups = (
@@ -1176,9 +1632,21 @@ class DotProductAttention(torch.nn.Module):
                    query_layer,
                    key_layer,
                    value_layer,
-                    attn_mask_type=attn_mask_type)
-            return self.flash_attention(
-                query_layer, key_layer, value_layer, attn_mask_type=attn_mask_type)
                    attn_mask_type=attn_mask_type,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attn_mask_type=attn_mask_type,
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
                                        cp_stream=self.cp_stream)

        assert (
            self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
        ), "Context parallelism is only implemented with Flash Attention!"

        if use_fused_attention:
            if checkpoint_core_attention:
@@ -1550,6 +2018,17 @@ class MultiheadAttention(torch.nn.Module):
        """Set TP group"""
        self.tp_group = tp_group
def set_context_parallel_running(
self,
cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group and CP dual-stream running"""
self.core_attention.cp_group = cp_group
self.core_attention.cp_global_ranks = cp_global_ranks
self.core_attention.cp_stream = cp_stream
    def forward(
        self,
        hidden_states: torch.Tensor,
......
@@ -427,6 +427,20 @@ class TransformerLayer(torch.nn.Module):
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)
def set_context_parallel_running(
self,
cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group and CP dual-stream running"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_running"):
child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
    def forward(
        self,
        hidden_states: torch.Tensor,
......
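A minimal usage sketch of the new hooks, assuming torch.distributed is already initialized via a torchrun-style launch and that all ranks in the job form a single context-parallel group; the layer sizes, sequence length, and group bookkeeping below are made up for illustration and are not part of this commit.

import torch
import transformer_engine.pytorch as te

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

cp_ranks = list(range(world_size))                     # one CP group spanning all ranks
cp_group = torch.distributed.new_group(ranks=cp_ranks)
cp_stream = torch.cuda.Stream()                        # extra stream for kernel overlap

# bf16 parameters/inputs so the flash-attention path (the only one with CP support) is taken
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
                            params_dtype=torch.bfloat16).cuda()
layer.set_context_parallel_running(cp_group, cp_ranks, cp_stream)

# each rank feeds its own shard of the sequence: [seq_len // cp_size, batch, hidden]
x = torch.randn(8192 // world_size, 2, 1024, device="cuda", dtype=torch.bfloat16)
out = layer(x)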