Commit f572ca96 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton kernel to optimize torch cat for ds prefill

parent 787c2557
...@@ -48,6 +48,21 @@ import math ...@@ -48,6 +48,21 @@ import math
(((896, 32, 512), (896, 32, 64)), 2), (((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2), (((1024, 32, 512), (1024, 32, 64)), 2),
(((4, 32, 128), (4, 32, 64)), 2),
(((8, 32, 128), (8, 32, 64)), 2),
(((16, 32, 128), (16, 32, 64)), 2),
(((32, 32, 128), (32, 32, 64)), 2),
(((64, 32, 128), (64, 32, 64)), 2),
(((128, 32, 128), (128, 32, 64)), 2),
(((256, 32, 128), (256, 32, 64)), 2),
(((512, 32, 128), (512, 32, 64)), 2),
(((672, 32, 128), (672, 32, 64)), 2),
(((768, 32, 128), (768, 32, 64)), 2),
(((896, 32, 128), (896, 32, 64)), 2),
(((1024, 32, 128), (1024, 32, 64)), 2),
]) ])
def test_concat_Acc(shape_pair, dim): def test_concat_Acc(shape_pair, dim):
...@@ -60,6 +75,47 @@ def test_concat_Acc(shape_pair, dim): ...@@ -60,6 +75,47 @@ def test_concat_Acc(shape_pair, dim):
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch" assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel_prefill(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)# 获取当前block的索引
for sub_section_index in range(Per_block//2):
sub_section_offset = block_idx * Per_block + sub_section_index * 2
if sub_section_offset <= section_num-1:
C_section_start = C_ptr + sub_section_offset * C_section_numel
A_section_start = A_ptr + sub_section_offset * A_section_numel
B_section_start = B_ptr + sub_section_offset * B_section_numel
Arrange_doubleA = tl.arange(0, 256)
mask = Arrange_doubleA < (256)
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
val_from_A = tl.load(A_section_start + Arrange_doubleA)
tensorAsn = tl.full((256,), 0, tl.int32)
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
off = Arrange2 + tensor_offsets
tl.store(C_section_start + off,val_from_A,mask=mask)
Arrange_doubleB = tl.arange(0, 128)
mask = Arrange_doubleB < (B_section_numel*2)
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
@triton.jit @triton.jit
def concat_kernel( def concat_kernel(
A_ptr, B_ptr, C_ptr, A_ptr, B_ptr, C_ptr,
...@@ -94,11 +150,25 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int): ...@@ -94,11 +150,25 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape = list(A.shape) output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim] output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype) C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 : if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim]) block_num = reduce(lambda x, y: x * y, output_shape[:dim])
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
Per_block = 1 Per_block = 1
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[0] > 64): unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
#case prefill
if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel_prefill[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
else:
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
Per_block = 2 Per_block = 2
num_blocks = math.ceil(block_num/Per_block) num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)]( concat_kernel[(num_blocks,)](
...@@ -160,7 +230,19 @@ def benchmark_32(size, provider, dim): ...@@ -160,7 +230,19 @@ def benchmark_32(size, provider, dim):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles) ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000) return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_prefill(size, provider, dim):
x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__': if __name__ == '__main__':
benchmark.run(save_path="./triton_test_8",print_data=True) # benchmark.run(save_path="./triton_test_8",print_data=True)
benchmark_16.run(save_path="./triton_test_16",print_data=True) # benchmark_16.run(save_path="./triton_test_16",print_data=True)
benchmark_32.run(save_path="./triton_test_32",print_data=True) # benchmark_32.run(save_path="./triton_test_32",print_data=True)
\ No newline at end of file benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment