Unverified Commit 476f659e authored by Kunlun Li, committed by GitHub

Add THD format support for Context Parallel (#641)
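
This enables the THD layout ("t, h, d": variable-length sequences packed along the token
dimension and described by cu_seqlens) for the FlashAttention context-parallel path. As in
the existing bshd/sbhd paths, every sequence is split into 2*cp_size chunks and CP rank r
processes chunks r and (2*cp_size - 1 - r), which keeps the causal-attention work roughly
balanced across ranks. New CUDA helpers (thd_read_half_tensor, thd_second_half_lse_correction,
thd_read_second_half_lse, thd_out_correction, thd_grad_correction, thd_get_partitioned_indices)
perform the half-sequence reads and the softmax-lse/output/gradient corrections that the
bshd/sbhd code does with simple view/index_select operations.

Below is a minimal pure-Python sketch of the token partitioning implemented by the new
thd_get_partitioned_indices extension; the sketch (including the function name) is
illustrative only and is not part of this change:

    import torch

    def thd_partition_indices_ref(cu_seqlens, world_size, rank):
        """Illustrative reference for tex.thd_get_partitioned_indices: returns the
        global THD token indices owned by CP rank `rank`. Assumes every sequence
        length is divisible by 2 * world_size (asserted by the CUDA kernel)."""
        indices = []
        for b in range(len(cu_seqlens) - 1):
            start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
            chunk = (end - start) // (2 * world_size)
            # rank r keeps chunks r and (2 * world_size - 1 - r) of each sequence
            for chunk_id in (rank, 2 * world_size - 1 - rank):
                indices.extend(range(start + chunk_id * chunk,
                                     start + (chunk_id + 1) * chunk))
        return torch.tensor(indices, dtype=torch.int32)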


Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent c473f0e6
......@@ -6,5 +6,5 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pip install pytest==7.2.0 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
......@@ -6,6 +6,7 @@ import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_extensions as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
......@@ -58,12 +59,27 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
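# Round each sequence length down to a multiple of (world_size * 2) so that every
# sequence can be split evenly into 2*cp_size chunks across the CP ranks.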
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
......@@ -79,6 +95,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
# make sure all GPU ranks have the same inputs
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_kv]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
......@@ -87,28 +106,48 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
q, k, v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
if qkv_format == "bshd" or qkv_format == "sbhd":
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
elif qkv_format == "thd":
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
bias_ = None
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn(
q_, k_, v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
out_.backward(dout_)
......@@ -120,11 +159,20 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
tols = dict(atol=2.5e-2, rtol=2.5e-2)
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
......@@ -143,6 +191,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
elif qkv_format == "thd":
torch.testing.assert_close(out_, out, **tols)
torch.testing.assert_close(dq_, dq, **tols)
torch.testing.assert_close(dk_, dk, **tols)
torch.testing.assert_close(dv_, dv, **tols)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
......
......@@ -33,7 +33,7 @@ def get_bash_arguments(**kwargs):
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
......
......@@ -676,8 +676,13 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i%2] = tex.thd_read_half_tensor(
kv_inputs[i%2], cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
......@@ -723,8 +728,13 @@ class AttnFuncWithCP(torch.autograd.Function):
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i%2] = \
q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
......@@ -782,7 +792,7 @@ class AttnFuncWithCP(torch.autograd.Function):
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal:
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
......@@ -791,8 +801,14 @@ class AttnFuncWithCP(torch.autograd.Function):
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if qkv_format == "thd":
tex.thd_second_half_lse_correction(softmax_lse,
softmax_lse_per_step[i-1],
cu_seqlens_q,
q.size(0))
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
......@@ -800,7 +816,8 @@ class AttnFuncWithCP(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
seq_dim = qkv_format.index("s")
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
for i in range(cp_size):
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
......@@ -808,18 +825,39 @@ class AttnFuncWithCP(torch.autograd.Function):
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]
if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i])
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i])
elif qkv_format == "thd":
tex.thd_out_correction(out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
False)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
else:
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
elif qkv_format == "thd":
tex.thd_out_correction(out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
True)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
kv = p2p_comm_buffers[-1]
if use_fused_attention:
......@@ -829,6 +867,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out = out.view(-1, *out.shape[-3:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.cp_group = cp_group
......@@ -873,12 +912,17 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_dbias = None
if ctx.causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
else:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = \
softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
......@@ -1013,8 +1057,12 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
......@@ -1063,15 +1111,23 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
......@@ -1144,16 +1200,22 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dq[0].copy_(dq_[0])
dq[1].add_(dq_[1])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
elif i > 0:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].add_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].add_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
else:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].copy_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
else:
if i == 0:
dq.copy_(dq_)
......@@ -1204,6 +1266,8 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
else:
dkv.add_(dkv_)
elif i >= (cp_size-rank-1):
......@@ -1212,11 +1276,15 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv[:, :, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
else:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
elif i > 0:
dkv.add_(dkv_)
else:
......@@ -1254,10 +1322,12 @@ def attn_forward_func_with_cp(
use_fused_attention=False
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
assert(qkv_format in ["bshd", "sbhd"]
assert(qkv_format in ["bshd", "sbhd", "thd"]
), f"QKV format of {qkv_format} is not supported with context parallelism!"
assert(qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert(not(qkv_format == "thd" and use_fused_attention)
), "FusedAttention does not support thd format!"
assert (attn_mask_type in ["causal", "no_mask"]
), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
assert (attn_bias is None or use_fused_attention
......@@ -2054,7 +2124,6 @@ class FlashAttention(torch.nn.Module):
key_layer.device,
)
elif qkv_format == 'thd':
assert not context_parallel, "thd format not supported with context parallelism!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
if max_seqlen_q is None:
......
......@@ -637,3 +637,45 @@ size_t get_cudnn_version();
bool userbuf_comm_available();
void placeholder();
/***************************************************************************************************
* Support THD format for Context Parallel
**************************************************************************************************/
at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
const at::Tensor &cu_seqlens,
int half_idx
);
void thd_second_half_lse_correction(at::Tensor lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
int total_tokens
);
at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
const at::Tensor &cu_seqlens,
int total_tokens
);
void thd_out_correction(at::Tensor out,
const at::Tensor &out_per_step,
const at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
bool only_second_half
);
void thd_grad_correction(at::Tensor grad,
const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens,
const std::string &first_half,
const std::string &second_half
);
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens,
int total_tokens,
int world_size,
int rank
);
......@@ -1440,3 +1440,622 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
return qkv;
}
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search
**************************************************************************************************/
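// Returns the index i (0 <= i <= len - 2) such that array[i] <= target < array[i+1].
// Used to map a token id to its sequence id via the (shared-memory) cu_seqlens array.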
__forceinline__
__device__ int binary_search(int target, int *array, int len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half,
void *tensor,
int *cu_seqlens,
int batch,
int hidden_size_in_bytes,
int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void*>(reinterpret_cast<char*>(half) + offset/2 * blockIdx.y);
tensor = reinterpret_cast<void*>(reinterpret_cast<char*>(tensor) + offset * blockIdx.y);
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4* cur_half_token = reinterpret_cast<float4*>(reinterpret_cast<char*>(half) + \
offset_in_bytes);
offset_in_bytes = (static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * \
hidden_size_in_bytes;
float4* cur_token = reinterpret_cast<float4*>(reinterpret_cast<char*>(tensor) + \
offset_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}
at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
const at::Tensor &cu_seqlens,
int half_idx) {
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens.size(0) >= 2);
// Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens.size(0) - 1;
int num_heads = tensor.size(seq_dim + 1);
int dim_per_head = tensor.size(seq_dim + 2);
int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
// For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
// Generate output
std::vector<int64_t> shape(tensor.dim());
for (size_t i = 0; i < shape.size(); i++) {
shape[i] = tensor.size(i);
}
shape[seq_dim] /= 2;
at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type()));
// Launch Kernel
constexpr unsigned int block = 256;
unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block;
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= tensor.size(i);
}
dim3 grid = {grid_x, grid_y};
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch+1),
at::cuda::getCurrentCUDAStream()>>>(
half.data_ptr(),
tensor.data_ptr(),
cu_seqlens.data_ptr<int>(),
batch,
hidden_size_in_bytes,
half_idx,
tensor.size(seq_dim));
return half;
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
int batch, int num_heads, int max_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
int num_total_tokens = cu_seqlens_s[batch];
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
size_t idx = row * max_seqlen + col + seq_len;
size_t half_idx = row * max_seqlen / 2 + col;
Functor::run(lse, half_lse, idx, half_idx);
}
}
}
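// Merge two softmax log-sum-exp values: lse <- log(exp(lse) + exp(lse_per_step)),
// evaluated in the numerically stable form max + log(1 + exp(min - max)).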
struct LseCorrectionFunctor {
__forceinline__
__device__ static void run(double *lse, float *half_lse, size_t idx, size_t half_idx) {
double val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
}
};
void thd_second_half_lse_correction(at::Tensor lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
int total_tokens) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<double, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch+1),
at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(),
lse_per_step.data_ptr<float>(),
cu_seqlens.data_ptr<int>(),
batch,
num_heads,
max_seqlen);
}
struct ReadLseFunctor {
__forceinline__
__device__ static void run(float *lse, float *half_lse, size_t idx, size_t half_idx) {
half_lse[half_idx] = lse[idx];
}
};
at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
const at::Tensor &cu_seqlens,
int total_tokens) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
std::vector<int64_t> shape = {batch, num_heads, max_seqlen / 2};
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<float, ReadLseFunctor><<<grid, block, sizeof(int) * (batch+1),
at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(),
half_lse.data_ptr<float>(),
cu_seqlens.data_ptr<int>(),
batch,
num_heads,
max_seqlen);
return half_lse;
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
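// Accumulate a per-step partial output into the final output: each token's contribution
// is rescaled by exp(lse_per_step - lse), i.e. its softmax weight relative to the merged
// log-sum-exp. With only_second_half=1, only the second half of every sequence is updated.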
template <typename dtype, int only_second_half, int tile_size>
__global__ void thd_out_correction_kernel(dtype *out,
dtype *out_per_step,
float *lse,
float *lse_per_step,
int *cu_seqlens,
int batch,
int num_heads,
int dim_per_head,
int max_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
}
__syncthreads();
int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
int lane_id = threadIdx.x % tile_size;
int num_tiles = (blockDim.x * gridDim.x) / tile_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);
for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * max_seqlen + col + seq_len * only_second_half;
idx_per_step = row * max_seqlen / (only_second_half + 1) + col;
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head;
dtype *cur_out = out + idx;
dtype *cur_out_per_step = out_per_step + idx_per_step;
for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
float4 data_per_step = reinterpret_cast<float4*>(cur_out_per_step)[j];
float4 data = reinterpret_cast<float4*>(cur_out)[j];
dtype *p_per_step = reinterpret_cast<dtype*>(&data_per_step);
dtype *p = reinterpret_cast<dtype*>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += p_per_step[k] * lse_corrected_exp;
}
reinterpret_cast<float4*>(cur_out)[j] = data;
}
}
}
}
template<typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out,
const at::Tensor &out_per_step,
const at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
int total_tokens = out.size(0);
int num_heads = out.size(1);
int dim_per_head = out.size(2);
int batch = lse.size(0);
int max_seqlen = lse.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);
NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x = (static_cast<size_t>(total_tokens) / (only_second_half + 1) * \
tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
thd_out_correction_kernel<dtype, only_second_half, tile><<<grid, block, sizeof(int) * (batch+1),
at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(),
out_per_step.data_ptr<dtype>(),
lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(),
cu_seqlens.data_ptr<int>(),
batch,
num_heads,
dim_per_head,
max_seqlen);
}
void thd_out_correction(at::Tensor out,
const at::Tensor &out_per_step,
const at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
bool only_second_half) {
if (only_second_half) {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
} else {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
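// functor_idx selects how grad_per_step maps onto grad:
//   0 -> grad_per_step holds only the first half of every sequence (Functor_0 is applied),
//   1 -> grad_per_step holds only the second half (Functor_1 is applied),
//   2 -> grad_per_step holds full sequences; Functor_0 is applied to the first half and
//        Functor_1 to the second half of each sequence.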
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad,
dtype *grad_per_step,
int *cu_seqlens,
int batch,
int hidden_size,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
if constexpr (functor_idx < 2) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
} else {
cu_seqlens_s[i] = cu_seqlens[i];
}
}
__syncthreads();
int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size;
int lane_id = threadIdx.x % group_size;
int num_groups = (blockDim.x * gridDim.x) / group_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size;
if constexpr (functor_idx < 2) {
grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
} else {
grad_per_step = grad_per_step + offset * blockIdx.y;
}
grad = grad + offset * blockIdx.y;
for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int token_offset;
bool is_first_half;
if constexpr (functor_idx < 2) {
token_offset = cu_seqlens_s[seq_id + functor_idx];
is_first_half = (functor_idx == 0);
} else {
token_offset = 0;
int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2);
}
dtype *token = &grad[(token_id + token_offset) * static_cast<size_t>(hidden_size)];
dtype *token_per_step = &grad_per_step[token_id * static_cast<size_t>(hidden_size)];
for (int idx = lane_id; idx < num_inner_loops; idx += group_size) {
if (is_first_half) {
Functor_0::run(token, token_per_step, idx);
} else {
Functor_1::run(token, token_per_step, idx);
}
}
}
}
struct EmptyFunctor {
__forceinline__
__device__ static void run(void *token, void *token_per_step, int idx) {}
};
struct CopyFunctor {
__forceinline__
__device__ static void run(void *token, void *token_per_step, int idx) {
reinterpret_cast<float4*>(token)[idx] = reinterpret_cast<float4*>(token_per_step)[idx];
}
};
template <typename dtype>
struct AddFunctor {
__forceinline__
__device__ static void run(dtype *token, dtype *token_per_step, int idx) {
float4 d_ = reinterpret_cast<float4*>(token)[idx];
dtype *p_ = reinterpret_cast<dtype*>(&d_);
float4 d = reinterpret_cast<float4*>(token_per_step)[idx];
dtype *p = reinterpret_cast<dtype*>(&d);
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
}
reinterpret_cast<float4*>(token)[idx] = d_;
}
};
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(at::Tensor grad,
const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens) {
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad.size(seq_dim);
int num_heads = grad.size(seq_dim + 1);
int dim_per_head = grad.size(seq_dim + 2);
int batch = cu_seqlens.size(0) - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens);
}
NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads);
NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad.size(i);
}
dim3 grid = {grid_x, grid_y};
thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
<<<grid, block, sizeof(int) * (batch+1), at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<dtype>(),
grad_per_step.data_ptr<dtype>(),
cu_seqlens.data_ptr<int>(),
batch,
hidden_size,
total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(at::Tensor grad,
const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens,
const std::string &first_half,
const std::string &second_half) {
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(
grad, grad_per_step, cu_seqlens);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(
grad, grad_per_step, cu_seqlens);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(
grad, grad_per_step, cu_seqlens);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(
grad, grad_per_step, cu_seqlens);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(
grad, grad_per_step, cu_seqlens);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(
grad, grad_per_step, cu_seqlens);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(at::Tensor grad,
const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens,
const std::string &first_half,
const std::string &second_half) {
if (grad.scalar_type() == at::ScalarType::Half) {
thd_grad_dispatcher<at::Half>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else if (grad.scalar_type() == at::ScalarType::BFloat16) {
thd_grad_dispatcher<at::BFloat16>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else if (grad.scalar_type() == at::ScalarType::Float) {
thd_grad_dispatcher<float>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else {
NVTE_ERROR("Unsupported dtype of grad\n");
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
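// Each sequence is split into (2 * world_size) chunks; rank r owns chunks r and
// (2 * world_size - 1 - r). For every token owned by `rank`, this kernel writes the
// token's index in the full (unpartitioned) THD tensor.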
__global__ void thd_partition_indices_kernel(int *output,
int *cu_seqlens,
int batch,
int total_tokens,
int world_size,
int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size*2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
int offset = index < seq_len/2 ? rank : (world_size-1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens,
int total_tokens,
int world_size,
int rank) {
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens.size(0) >= 2);
NVTE_CHECK(rank >= 0 && rank < world_size);
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens.size(0) - 1;
std::vector<int64_t> shape = {total_tokens / world_size};
at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
constexpr unsigned int block = 256;
unsigned int grid = (output.size(0) + block - 1) / block;
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch+1),
at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
batch,
total_tokens,
world_size,
rank);
return output;
}
......@@ -102,6 +102,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor");
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
"Correct the second half of the softmax_lse");
m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
"Read the second half of the softmax_lse");
m.def("thd_out_correction", &thd_out_correction,
"Correct the THD format output of context parallelism in forward pass");
m.def("thd_grad_correction", &thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass");
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
......