Commit 90e40f49 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton kernel to optimize torch cat for ds prefill

parent ead74dfa
...@@ -47,6 +47,21 @@ import math ...@@ -47,6 +47,21 @@ import math
(((768, 32, 512), (768, 32, 64)), 2), (((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2), (((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2), (((1024, 32, 512), (1024, 32, 64)), 2),
(((4, 32, 128), (4, 32, 64)), 2),
(((8, 32, 128), (8, 32, 64)), 2),
(((16, 32, 128), (16, 32, 64)), 2),
(((32, 32, 128), (32, 32, 64)), 2),
(((64, 32, 128), (64, 32, 64)), 2),
(((128, 32, 128), (128, 32, 64)), 2),
(((256, 32, 128), (256, 32, 64)), 2),
(((512, 32, 128), (512, 32, 64)), 2),
(((672, 32, 128), (672, 32, 64)), 2),
(((768, 32, 128), (768, 32, 64)), 2),
(((896, 32, 128), (896, 32, 64)), 2),
(((1024, 32, 128), (1024, 32, 64)), 2),
]) ])
def test_concat_Acc(shape_pair, dim): def test_concat_Acc(shape_pair, dim):
...@@ -60,6 +75,47 @@ def test_concat_Acc(shape_pair, dim): ...@@ -60,6 +75,47 @@ def test_concat_Acc(shape_pair, dim):
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch" assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel_prefill(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)# 获取当前block的索引
for sub_section_index in range(Per_block//2):
sub_section_offset = block_idx * Per_block + sub_section_index * 2
if sub_section_offset <= section_num-1:
C_section_start = C_ptr + sub_section_offset * C_section_numel
A_section_start = A_ptr + sub_section_offset * A_section_numel
B_section_start = B_ptr + sub_section_offset * B_section_numel
Arrange_doubleA = tl.arange(0, 256)
mask = Arrange_doubleA < (256)
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
val_from_A = tl.load(A_section_start + Arrange_doubleA)
tensorAsn = tl.full((256,), 0, tl.int32)
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
off = Arrange2 + tensor_offsets
tl.store(C_section_start + off,val_from_A,mask=mask)
Arrange_doubleB = tl.arange(0, 128)
mask = Arrange_doubleB < (B_section_numel*2)
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
@triton.jit @triton.jit
def concat_kernel( def concat_kernel(
A_ptr, B_ptr, C_ptr, A_ptr, B_ptr, C_ptr,
...@@ -94,20 +150,34 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int): ...@@ -94,20 +150,34 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape = list(A.shape) output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim] output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype) C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 : if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim]) block_num = reduce(lambda x, y: x * y, output_shape[:dim])
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1) Per_block = 1
Per_block = 1 unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[0] > 64): #case prefill
Per_block = 2 if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
num_blocks = math.ceil(block_num/Per_block) Per_block = 8
concat_kernel[(num_blocks,)]( num_blocks = math.ceil(block_num/Per_block)
A, B, C, concat_kernel_prefill[(num_blocks,)](
unit_offset_A, unit_offset_B, unit_offset_C, A, B, C,
Per_block, unit_offset_A, unit_offset_B, unit_offset_C,
block_num, Per_block,
BLOCK_SIZE=1024) block_num,
return C BLOCK_SIZE=1024)
return C
else:
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
Per_block = 2
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
assert False, "not support" assert False, "not support"
...@@ -160,9 +230,19 @@ def benchmark_32(size, provider, dim): ...@@ -160,9 +230,19 @@ def benchmark_32(size, provider, dim):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles) ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000) return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__': @triton.testing.perf_report(configs)
benchmark.run(save_path="./triton_test_8",print_data=True) def benchmark_prefill(size, provider, dim):
benchmark_16.run(save_path="./triton_test_16",print_data=True) x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
benchmark_32.run(save_path="./triton_test_32",print_data=True) y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
\ No newline at end of file # benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment