Commit 2c169409 authored by zhuwenwen's avatar zhuwenwen
Browse files

update the cat implementation of triton's non contiguous memory for the decode phase

parent e5f51b79
...@@ -50,26 +50,34 @@ import math ...@@ -50,26 +50,34 @@ import math
(((4, 32, 128), (4, 32, 64)), 2),
(((8, 32, 128), (8, 32, 64)), 2),
(((16, 32, 128), (16, 32, 64)), 2),
(((32, 32, 128), (32, 32, 64)), 2),
(((64, 32, 128), (64, 32, 64)), 2),
(((128, 32, 128), (128, 32, 64)), 2),
(((256, 32, 128), (256, 32, 64)), 2),
(((512, 32, 128), (512, 32, 64)), 2),
(((672, 32, 128), (672, 32, 64)), 2),
(((768, 32, 128), (768, 32, 64)), 2),
(((896, 32, 128), (896, 32, 64)), 2),
(((1024, 32, 128), (1024, 32, 64)), 2),
]) ])
def test_concat_Acc(shape_pair, dim): def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1) torch.manual_seed(1)
shape1, shape2 = shape_pair shape1, shape2 = shape_pair
x = torch.randn(*shape1, device='cuda', dtype=torch.bfloat16) M = shape1[0]
y = torch.randn(*shape2, device='cuda', dtype=torch.bfloat16) N = shape1[1]
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape) # [4, 8, 512]
# print("步幅:", x.stride()) # (1536, 192, 1)
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
expected = torch.cat([x,y], dim=dim) expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim) result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch" assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
...@@ -78,59 +86,31 @@ def test_concat_Acc(shape_pair, dim): ...@@ -78,59 +86,31 @@ def test_concat_Acc(shape_pair, dim):
@triton.jit
def concat_kernel_prefill(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)# 获取当前block的索引
for sub_section_index in range(Per_block//2):
sub_section_offset = block_idx * Per_block + sub_section_index * 2
if sub_section_offset <= section_num-1:
C_section_start = C_ptr + sub_section_offset * C_section_numel
A_section_start = A_ptr + sub_section_offset * A_section_numel
B_section_start = B_ptr + sub_section_offset * B_section_numel
Arrange_doubleA = tl.arange(0, 256)
mask = Arrange_doubleA < (256)
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
val_from_A = tl.load(A_section_start + Arrange_doubleA)
tensorAsn = tl.full((256,), 0, tl.int32)
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
off = Arrange2 + tensor_offsets
tl.store(C_section_start + off,val_from_A,mask=mask)
Arrange_doubleB = tl.arange(0, 128)
mask = Arrange_doubleB < (B_section_numel*2)
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
@triton.jit @triton.jit
def concat_kernel( def concat_kernel(
A_ptr, B_ptr, C_ptr, A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel, A_section_numel, B_section_numel, C_section_numel,
Per_block, Per_block,
section_num, section_num,
M,
N,
Astride_0,
Astride_1,
Astride_2,
Bstride_0,
Bstride_1,
Bstride_2,
BLOCK_SIZE: tl.constexpr BLOCK_SIZE: tl.constexpr
): ):
block_idx = tl.program_id(0) block_idx = tl.program_id(0)
for sub_section_index in range(Per_block): for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index sub_offset = block_idx * Per_block + sub_section_index
M_idx = sub_offset // N
N_idx = sub_offset % N
if sub_offset <= section_num-1: if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + sub_offset * A_section_numel A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
B_ptr_block_start = B_ptr + sub_offset * B_section_numel B_ptr_block_start = B_ptr + M_idx * Bstride_0 + N_idx * Bstride_1
for offset in range(0, A_section_numel, BLOCK_SIZE): for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE) offset_idx = offset + tl.arange(0, BLOCK_SIZE)
...@@ -145,8 +125,7 @@ def concat_kernel( ...@@ -145,8 +125,7 @@ def concat_kernel(
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask) tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int): def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
A = A.contiguous()
B = B.contiguous()
output_shape = list(A.shape) output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim] output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype) C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
...@@ -154,38 +133,38 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int): ...@@ -154,38 +133,38 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
if dim!=0 : if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim]) block_num = reduce(lambda x, y: x * y, output_shape[:dim])
Per_block = 1 Per_block = 1
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1) unit_offset_A, unit_offset_B, unit_offset_C = A.shape[dim],B.shape[dim],C.shape[dim]
#case prefill if (A.shape[1]==8 and A.shape[0] > 512) or ( A.shape[1]==16 and A.shape[0] > 256):
if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel_prefill[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
else:
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
Per_block = 2 Per_block = 2
if ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 256):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block) num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)]( concat_kernel[(num_blocks,)](
A, B, C, A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C, unit_offset_A, unit_offset_B, unit_offset_C,
Per_block, Per_block,
block_num, block_num,
output_shape[0],
output_shape[1],
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
BLOCK_SIZE=1024) BLOCK_SIZE=1024)
return C return C
assert False, "not support" assert False, "not support"
configs = [] configs = []
configs.append( configs.append(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['size'], x_names=['M','N'],
x_vals=[4,8,16,32,64,96,128,256,512,768,1024], x_vals=[(4,8),(8,8),(16,8),(32,8),(64,8),(96,8),(128,8),(256,8),(512,8),(768,8),(1024,8), \
(4,16),(8,16),(16,16),(32,16),(64,16),(96,16),(128,16),(256,16),(512,16),(768,16),(1024,16), \
(4,32),(8,32),(16,32),(32,32),(64,32),(96,32),(128,32),(256,32),(512,32),(768,32),(1024,32)],
x_log=True, x_log=True,
line_arg='provider', line_arg='provider',
line_vals=['triton', 'torch'], line_vals=['triton', 'torch'],
...@@ -198,42 +177,28 @@ configs.append( ...@@ -198,42 +177,28 @@ configs.append(
) )
@triton.testing.perf_report(configs) @triton.testing.perf_report(configs)
def benchmark(size, provider, dim): def benchmark(M, N, provider, dim):
x = torch.rand([size,8,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,8,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs) x_sizes = [M, N, 512]
def benchmark_16(size, provider, dim): x_strides = [512, 512*M, 1]
x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16) x_max_index = M * N * 512
y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16) x_required_length = x_max_index
quantiles = [0.5, 0.2, 0.8] x_data = torch.arange(x_required_length,device='cuda').bfloat16()
if provider == 'torch': x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs) # print("形状:", x.shape) # [M, 8, 512]
def benchmark_32(size, provider, dim): # print("步幅:", x.stride()) # (512, 512*M, 1)
x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16) y_sizes = [M, N, 64]
quantiles = [0.5, 0.2, 0.8] y_strides = [1536*(N//8), 192, 1]
if provider == 'torch': y_max_index = 1536*(N//8) * M
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles) y_required_length = y_max_index
if provider == 'triton': y_data = torch.arange(y_required_length,device='cuda').bfloat16()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles) y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
return (ms*1000), (max_ms*1000), (min_ms*1000)
# print("形状:", y.shape) # [M, 8, 64]
# print("步幅:", y.stride()) # (1536, 192, 1)
@triton.testing.perf_report(configs)
def benchmark_prefill(size, provider, dim):
x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == 'torch': if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
...@@ -241,8 +206,41 @@ def benchmark_prefill(size, provider, dim): ...@@ -241,8 +206,41 @@ def benchmark_prefill(size, provider, dim):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles) ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000) return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_16(size, provider, dim):
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_32(size, provider, dim):
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_prefill(size, provider, dim):
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__': if __name__ == '__main__':
# benchmark.run(save_path="./triton_test_8",print_data=True) benchmark.run(save_path="./triton_test",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True) # benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True) # benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True) # benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment