Unverified Commit 476f659e authored by Kunlun Li's avatar Kunlun Li Committed by GitHub
Browse files

Add THD format support for Context Parallel (#641)


Signed-off-by: default avatarkunlunl <kunlunl@nvidia.com>
Co-authored-by: default avatarcyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent c473f0e6
......@@ -6,5 +6,5 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pip install pytest==7.2.0 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
......@@ -6,6 +6,7 @@ import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_extensions as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
......@@ -58,12 +59,27 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
......@@ -79,6 +95,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
# make sure all GPU ranks have same inputs
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_kv]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
......@@ -87,28 +106,48 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
q, k, v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
if qkv_format == "bshd" or qkv_format == "sbhd":
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
elif qkv_format == "thd":
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
bias_ = None
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn(
q_, k_, v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
out_.backward(dout_)
......@@ -120,11 +159,20 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
tols = dict(atol=2.5e-2, rtol=2.5e-2)
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
......@@ -143,6 +191,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
elif qkv_format == "thd":
torch.testing.assert_close(out_, out, **tols)
torch.testing.assert_close(dq_, dq, **tols)
torch.testing.assert_close(dk_, dk, **tols)
torch.testing.assert_close(dv_, dv, **tols)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
......
......@@ -33,7 +33,7 @@ def get_bash_arguments(**kwargs):
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
......
......@@ -676,8 +676,13 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i%2] = tex.thd_read_half_tensor(
kv_inputs[i%2], cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
......@@ -723,8 +728,13 @@ class AttnFuncWithCP(torch.autograd.Function):
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i%2] = \
q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
......@@ -782,7 +792,7 @@ class AttnFuncWithCP(torch.autograd.Function):
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal:
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
......@@ -791,8 +801,14 @@ class AttnFuncWithCP(torch.autograd.Function):
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if qkv_format == "thd":
tex.thd_second_half_lse_correction(softmax_lse,
softmax_lse_per_step[i-1],
cu_seqlens_q,
q.size(0))
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
......@@ -800,7 +816,8 @@ class AttnFuncWithCP(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
seq_dim = qkv_format.index("s")
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
for i in range(cp_size):
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
......@@ -808,18 +825,39 @@ class AttnFuncWithCP(torch.autograd.Function):
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]
if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i])
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i])
elif qkv_format == "thd":
tex.thd_out_correction(out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
False)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
else:
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
elif qkv_format == "thd":
tex.thd_out_correction(out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
True)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
kv = p2p_comm_buffers[-1]
if use_fused_attention:
......@@ -829,6 +867,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out = out.view(-1, *out.shape[-3:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.cp_group = cp_group
......@@ -873,12 +912,17 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_dbias = None
if ctx.causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
else:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = \
softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
......@@ -1013,8 +1057,12 @@ class AttnFuncWithCP(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
......@@ -1063,15 +1111,23 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_bias_type=ctx.attn_bias_type,
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
......@@ -1144,16 +1200,22 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dq[0].copy_(dq_[0])
dq[1].add_(dq_[1])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
elif i > 0:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].add_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].add_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
else:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].copy_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
else:
if i == 0:
dq.copy_(dq_)
......@@ -1204,6 +1266,8 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
else:
dkv.add_(dkv_)
elif i >= (cp_size-rank-1):
......@@ -1212,11 +1276,15 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv[:, :, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
else:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
elif i > 0:
dkv.add_(dkv_)
else:
......@@ -1254,10 +1322,12 @@ def attn_forward_func_with_cp(
use_fused_attention=False
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
assert(qkv_format in ["bshd", "sbhd"]
assert(qkv_format in ["bshd", "sbhd", "thd"]
), f"QKV format of {qkv_format} is not supported with context parallelism!"
assert(qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert(not(qkv_format == "thd" and use_fused_attention)
), "FusedAttention does not support thd format!"
assert (attn_mask_type in ["causal", "no_mask"]
), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
assert (attn_bias is None or use_fused_attention
......@@ -2054,7 +2124,6 @@ class FlashAttention(torch.nn.Module):
key_layer.device,
)
elif qkv_format == 'thd':
assert not context_parallel, "thd format not supported with context parallelism!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
if max_seqlen_q is None:
......
......@@ -637,3 +637,45 @@ size_t get_cudnn_version();
bool userbuf_comm_available();
void placeholder();
/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/
// Read the first half (half_idx=0) or the second half (half_idx=1) of each
// sequence in a THD-format tensor; sequence boundaries are given by cu_seqlens.
at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
                                const at::Tensor &cu_seqlens,
                                int half_idx
);
// Correct (in place) the second half of softmax_lse `lse` using the
// per-step LSE `lse_per_step` for the sequences described by cu_seqlens.
void thd_second_half_lse_correction(at::Tensor lse,
                                    const at::Tensor &lse_per_step,
                                    const at::Tensor &cu_seqlens,
                                    int total_tokens
);
// Return the second half of the softmax_lse for each sequence in cu_seqlens.
at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
                                    const at::Tensor &cu_seqlens,
                                    int total_tokens
);
// Correct (in place) the THD-format attention output of context parallelism
// in the forward pass; when only_second_half is true, only the second half of
// each sequence is corrected.
void thd_out_correction(at::Tensor out,
                        const at::Tensor &out_per_step,
                        const at::Tensor &lse,
                        const at::Tensor &lse_per_step,
                        const at::Tensor &cu_seqlens,
                        bool only_second_half
);
// Correct (in place) THD-format gradients of context parallelism in the
// backward pass. first_half/second_half select the per-half operation;
// call sites pass "copy", "add", or "none".
void thd_grad_correction(at::Tensor grad,
                         const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens,
                         const std::string &first_half,
                         const std::string &second_half
);
// Generate this rank's partitioned token indices for THD-format inputs,
// splitting each sequence across world_size context-parallel ranks.
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens,
                                       int total_tokens,
                                       int world_size,
                                       int rank
);
......@@ -102,6 +102,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor");
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
"Correct the second half of the softmax_lse");
m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
"Read the second half of the softmax_lse");
m.def("thd_out_correction", &thd_out_correction,
"Correct the THD format output of context parallelism in forward pass");
m.def("thd_grad_correction", &thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass");
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment