Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f572ca96
Commit
f572ca96
authored
Sep 10, 2025
by
zhuwenwen
Browse files
update triton kernel to optimize torch cat for ds prefill
parent
787c2557
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
16 deletions
+98
-16
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
+98
-16
No files found.
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
View file @
f572ca96
...
...
@@ -48,6 +48,21 @@ import math
(((
896
,
32
,
512
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
512
),
(
1024
,
32
,
64
)),
2
),
(((
4
,
32
,
128
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
128
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
128
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
128
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
128
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
128
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
128
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
128
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
128
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
128
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
128
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
128
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
...
...
@@ -60,6 +75,47 @@ def test_concat_Acc(shape_pair, dim):
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
@
triton
.
jit
def
concat_kernel_prefill
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
# 获取当前block的索引
for
sub_section_index
in
range
(
Per_block
//
2
):
sub_section_offset
=
block_idx
*
Per_block
+
sub_section_index
*
2
if
sub_section_offset
<=
section_num
-
1
:
C_section_start
=
C_ptr
+
sub_section_offset
*
C_section_numel
A_section_start
=
A_ptr
+
sub_section_offset
*
A_section_numel
B_section_start
=
B_ptr
+
sub_section_offset
*
B_section_numel
Arrange_doubleA
=
tl
.
arange
(
0
,
256
)
mask
=
Arrange_doubleA
<
(
256
)
Arrange2
=
(
tl
.
arange
(
0
,
128
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
256
)
val_from_A
=
tl
.
load
(
A_section_start
+
Arrange_doubleA
)
tensorAsn
=
tl
.
full
((
256
,),
0
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
256
,),
(
C_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleA
<
A_section_numel
,
tensorAsn
,
tensorAsn2
)
off
=
Arrange2
+
tensor_offsets
tl
.
store
(
C_section_start
+
off
,
val_from_A
,
mask
=
mask
)
Arrange_doubleB
=
tl
.
arange
(
0
,
128
)
mask
=
Arrange_doubleB
<
(
B_section_numel
*
2
)
val_from_B
=
tl
.
load
(
B_section_start
+
Arrange_doubleB
,
mask
=
mask
)
Arrange3
=
(
tl
.
arange
(
0
,
64
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
128
)
tensorAsn
=
tl
.
full
((
128
,),
A_section_numel
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
128
,),
(
C_section_numel
+
A_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleB
<
B_section_numel
,
tensorAsn
,
tensorAsn2
)
tl
.
store
(
C_section_start
+
Arrange3
+
tensor_offsets
,
val_from_B
)
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
...
...
@@ -94,11 +150,25 @@ def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
Per_block
=
1
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
0
]
>
64
):
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
#case prefill
if
(
A
.
shape
[
2
]
==
128
and
B
.
shape
[
2
]
==
64
and
A
.
shape
[
0
]
>
16
):
Per_block
=
8
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel_prefill
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
else
:
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
64
):
Per_block
=
2
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
...
...
@@ -160,7 +230,19 @@ def benchmark_32(size, provider, dim):
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_prefill
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
128
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
if
__name__
==
'__main__'
:
benchmark
.
run
(
save_path
=
"./triton_test_8"
,
print_data
=
True
)
benchmark_16
.
run
(
save_path
=
"./triton_test_16"
,
print_data
=
True
)
benchmark_32
.
run
(
save_path
=
"./triton_test_32"
,
print_data
=
True
)
\ No newline at end of file
# benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill
.
run
(
save_path
=
"./triton_test_prefill"
,
print_data
=
True
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment