Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2c169409
Commit
2c169409
authored
Sep 12, 2025
by
zhuwenwen
Browse files
update the cat implementation of triton's non contiguous memory for the decode phase
parent
e5f51b79
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
114 deletions
+112
-114
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
+112
-114
No files found.
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
View file @
2c169409
...
...
@@ -50,26 +50,34 @@ import math
(((
4
,
32
,
128
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
128
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
128
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
128
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
128
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
128
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
128
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
128
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
128
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
128
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
128
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
128
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
x
=
torch
.
randn
(
*
shape1
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
(
*
shape2
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
M
=
shape1
[
0
]
N
=
shape1
[
1
]
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
# print("形状:", x.shape) # [4, 8, 512]
# print("步幅:", x.stride()) # (1536, 192, 1)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
...
...
@@ -78,59 +86,31 @@ def test_concat_Acc(shape_pair, dim):
@
triton
.
jit
def
concat_kernel_prefill
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
# 获取当前block的索引
for
sub_section_index
in
range
(
Per_block
//
2
):
sub_section_offset
=
block_idx
*
Per_block
+
sub_section_index
*
2
if
sub_section_offset
<=
section_num
-
1
:
C_section_start
=
C_ptr
+
sub_section_offset
*
C_section_numel
A_section_start
=
A_ptr
+
sub_section_offset
*
A_section_numel
B_section_start
=
B_ptr
+
sub_section_offset
*
B_section_numel
Arrange_doubleA
=
tl
.
arange
(
0
,
256
)
mask
=
Arrange_doubleA
<
(
256
)
Arrange2
=
(
tl
.
arange
(
0
,
128
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
256
)
val_from_A
=
tl
.
load
(
A_section_start
+
Arrange_doubleA
)
tensorAsn
=
tl
.
full
((
256
,),
0
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
256
,),
(
C_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleA
<
A_section_numel
,
tensorAsn
,
tensorAsn2
)
off
=
Arrange2
+
tensor_offsets
tl
.
store
(
C_section_start
+
off
,
val_from_A
,
mask
=
mask
)
Arrange_doubleB
=
tl
.
arange
(
0
,
128
)
mask
=
Arrange_doubleB
<
(
B_section_numel
*
2
)
val_from_B
=
tl
.
load
(
B_section_start
+
Arrange_doubleB
,
mask
=
mask
)
Arrange3
=
(
tl
.
arange
(
0
,
64
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
128
)
tensorAsn
=
tl
.
full
((
128
,),
A_section_numel
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
128
,),
(
C_section_numel
+
A_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleB
<
B_section_numel
,
tensorAsn
,
tensorAsn2
)
tl
.
store
(
C_section_start
+
Arrange3
+
tensor_offsets
,
val_from_B
)
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
section_num
,
M
,
N
,
Astride_0
,
Astride_1
,
Astride_2
,
Bstride_0
,
Bstride_1
,
Bstride_2
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
for
sub_section_index
in
range
(
Per_block
):
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
M_idx
=
sub_offset
//
N
N_idx
=
sub_offset
%
N
if
sub_offset
<=
section_num
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
sub_offset
*
A_section_numel
B_ptr_block_start
=
B_ptr
+
sub_offset
*
B_section_numel
A_ptr_block_start
=
A_ptr
+
M_idx
*
Astride_0
+
N_idx
*
Astride_1
B_ptr_block_start
=
B_ptr
+
M_idx
*
Bstride_0
+
N_idx
*
Bstride_1
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
...
...
@@ -145,8 +125,7 @@ def concat_kernel(
tl
.
store
(
C_ptr_block_start
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
A
=
A
.
contiguous
()
B
=
B
.
contiguous
()
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
...
...
@@ -154,38 +133,38 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
Per_block
=
1
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
#case prefill
if
(
A
.
shape
[
2
]
==
128
and
B
.
shape
[
2
]
==
64
and
A
.
shape
[
0
]
>
16
):
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
shape
[
dim
],
B
.
shape
[
dim
],
C
.
shape
[
dim
]
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
512
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
256
):
Per_block
=
2
if
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
256
):
Per_block
=
8
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel_prefill
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
else
:
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
64
):
Per_block
=
2
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
block_num
,
output_shape
[
0
],
output_shape
[
1
],
A
.
stride
(
0
),
A
.
stride
(
1
),
A
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
1
),
B
.
stride
(
2
),
BLOCK_SIZE
=
1024
)
return
C
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'size'
],
x_vals
=
[
4
,
8
,
16
,
32
,
64
,
96
,
128
,
256
,
512
,
768
,
1024
],
x_names
=
[
'M'
,
'N'
],
x_vals
=
[(
4
,
8
),(
8
,
8
),(
16
,
8
),(
32
,
8
),(
64
,
8
),(
96
,
8
),(
128
,
8
),(
256
,
8
),(
512
,
8
),(
768
,
8
),(
1024
,
8
),
\
(
4
,
16
),(
8
,
16
),(
16
,
16
),(
32
,
16
),(
64
,
16
),(
96
,
16
),(
128
,
16
),(
256
,
16
),(
512
,
16
),(
768
,
16
),(
1024
,
16
),
\
(
4
,
32
),(
8
,
32
),(
16
,
32
),(
32
,
32
),(
64
,
32
),(
96
,
32
),(
128
,
32
),(
256
,
32
),(
512
,
32
),(
768
,
32
),(
1024
,
32
)],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
],
...
...
@@ -198,31 +177,28 @@ configs.append(
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
8
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
8
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
def
benchmark
(
M
,
N
,
provider
,
dim
):
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_16
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
16
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
16
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_32
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
# print("形状:", x.shape) # [M, 8, 512]
# print("步幅:", x.stride()) # (512, 512*M, 1)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
# print("形状:", y.shape) # [M, 8, 64]
# print("步幅:", y.stride()) # (1536, 192, 1)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
...
...
@@ -230,19 +206,41 @@ def benchmark_32(size, provider, dim):
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_prefill
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
128
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
# @triton.testing.perf_report(configs)
# def benchmark_16(size, provider, dim):
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_32(size, provider, dim):
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_prefill(size, provider, dim):
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
if
__name__
==
'__main__'
:
#
benchmark.run(save_path="./triton_test
_8
",print_data=True)
benchmark
.
run
(
save_path
=
"./triton_test"
,
print_data
=
True
)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill
.
run
(
save_path
=
"./triton_test_prefill"
,
print_data
=
True
)
\ No newline at end of file
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment