Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b1329ff2
Commit
b1329ff2
authored
Sep 13, 2025
by
zhuwenwen
Browse files
update triton cat scheduling
update the default values of VLLM_USE_TRITON_CAT and VLLM_USE_LIGHT_OP to True
parent
2c169409
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
263 additions
and
7 deletions
+263
-7
vllm/envs.py
vllm/envs.py
+2
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
+250
-0
vllm/v1/attention/backends/mla/concatv4_decode_only.py
vllm/v1/attention/backends/mla/concatv4_decode_only.py
+3
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-3
No files found.
vllm/envs.py
View file @
b1329ff2
...
@@ -1093,11 +1093,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1093,11 +1093,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
# vLLM will use global cache for moe
"VLLM_USE_LIGHT_OP"
:
"VLLM_USE_LIGHT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHT_OP"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHT_OP"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
# vLLM will use global cache for moe
"VLLM_USE_TRITON_CAT"
:
"VLLM_USE_TRITON_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_CAT"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_CAT"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
}
}
...
...
vllm/v1/attention/backends/mla/common.py
View file @
b1329ff2
...
@@ -216,7 +216,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...
@@ -216,7 +216,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.attention.backends.mla.concatv3Tritonfinal
import
concat_helper
from
vllm.v1.attention.backends.mla.concatv3Tritonfinal
v2
import
concat_helper
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
0 → 100644
View file @
b1329ff2
import
triton
import
triton.language
as
tl
import
torch
from
functools
import
reduce
import
pytest
import
torch
import
math
@
pytest
.
mark
.
parametrize
(
"shape_pair,dim"
,
[
(((
4
,
8
,
512
),
(
4
,
8
,
64
)),
2
),
(((
8
,
8
,
512
),
(
8
,
8
,
64
)),
2
),
(((
16
,
8
,
512
),
(
16
,
8
,
64
)),
2
),
(((
32
,
8
,
512
),
(
32
,
8
,
64
)),
2
),
(((
64
,
8
,
512
),
(
64
,
8
,
64
)),
2
),
(((
128
,
8
,
512
),
(
128
,
8
,
64
)),
2
),
(((
256
,
8
,
512
),
(
256
,
8
,
64
)),
2
),
(((
512
,
8
,
512
),
(
512
,
8
,
64
)),
2
),
(((
672
,
8
,
512
),
(
672
,
8
,
64
)),
2
),
(((
768
,
8
,
512
),
(
768
,
8
,
64
)),
2
),
(((
896
,
8
,
512
),
(
896
,
8
,
64
)),
2
),
(((
1024
,
8
,
512
),
(
1024
,
8
,
64
)),
2
),
(((
4
,
16
,
512
),
(
4
,
16
,
64
)),
2
),
(((
8
,
16
,
512
),
(
8
,
16
,
64
)),
2
),
(((
16
,
16
,
512
),
(
16
,
16
,
64
)),
2
),
(((
32
,
16
,
512
),
(
32
,
16
,
64
)),
2
),
(((
64
,
16
,
512
),
(
64
,
16
,
64
)),
2
),
(((
128
,
16
,
512
),
(
128
,
16
,
64
)),
2
),
(((
256
,
16
,
512
),
(
256
,
16
,
64
)),
2
),
(((
512
,
16
,
512
),
(
512
,
16
,
64
)),
2
),
(((
672
,
16
,
512
),
(
672
,
16
,
64
)),
2
),
(((
768
,
16
,
512
),
(
768
,
16
,
64
)),
2
),
(((
896
,
16
,
512
),
(
896
,
16
,
64
)),
2
),
(((
1024
,
16
,
512
),
(
1024
,
16
,
64
)),
2
),
(((
4
,
32
,
512
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
512
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
512
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
512
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
512
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
512
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
512
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
512
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
512
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
512
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
512
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
512
),
(
1024
,
32
,
64
)),
2
),
(((
4
,
32
,
128
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
128
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
128
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
128
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
128
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
128
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
128
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
128
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
128
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
128
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
128
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
128
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
x
=
torch
.
randn
(
*
shape1
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
(
*
shape2
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
@
triton
.
jit
def
concat_kernel_prefill
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
# 获取当前block的索引
for
sub_section_index
in
range
(
Per_block
//
2
):
sub_section_offset
=
block_idx
*
Per_block
+
sub_section_index
*
2
if
sub_section_offset
<=
section_num
-
1
:
C_section_start
=
C_ptr
+
sub_section_offset
*
C_section_numel
A_section_start
=
A_ptr
+
sub_section_offset
*
A_section_numel
B_section_start
=
B_ptr
+
sub_section_offset
*
B_section_numel
Arrange_doubleA
=
tl
.
arange
(
0
,
256
)
mask
=
Arrange_doubleA
<
(
256
)
Arrange2
=
(
tl
.
arange
(
0
,
128
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
256
)
val_from_A
=
tl
.
load
(
A_section_start
+
Arrange_doubleA
)
tensorAsn
=
tl
.
full
((
256
,),
0
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
256
,),
(
C_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleA
<
A_section_numel
,
tensorAsn
,
tensorAsn2
)
off
=
Arrange2
+
tensor_offsets
tl
.
store
(
C_section_start
+
off
,
val_from_A
,
mask
=
mask
)
Arrange_doubleB
=
tl
.
arange
(
0
,
128
)
mask
=
Arrange_doubleB
<
(
B_section_numel
*
2
)
val_from_B
=
tl
.
load
(
B_section_start
+
Arrange_doubleB
,
mask
=
mask
)
Arrange3
=
(
tl
.
arange
(
0
,
64
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
128
)
tensorAsn
=
tl
.
full
((
128
,),
A_section_numel
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
128
,),
(
C_section_numel
+
A_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleB
<
B_section_numel
,
tensorAsn
,
tensorAsn2
)
tl
.
store
(
C_section_start
+
Arrange3
+
tensor_offsets
,
val_from_B
)
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
for
sub_section_index
in
range
(
Per_block
):
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
if
sub_offset
<=
section_num
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
sub_offset
*
A_section_numel
B_ptr_block_start
=
B_ptr
+
sub_offset
*
B_section_numel
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
A_section_numel
val_from_A
=
tl
.
load
(
A_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
offset_idx
,
val_from_A
,
mask
=
mask
)
for
offset
in
range
(
0
,
B_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
B_section_numel
val_from_B
=
tl
.
load
(
B_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
A
=
A
.
contiguous
()
B
=
B
.
contiguous
()
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
Per_block
=
1
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
#case prefill
if
(
A
.
shape
[
2
]
==
128
and
B
.
shape
[
2
]
==
64
and
A
.
shape
[
0
]
>
16
):
Per_block
=
8
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel_prefill
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
else
:
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
64
):
Per_block
=
2
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'size'
],
x_vals
=
[
4
,
8
,
16
,
32
,
64
,
96
,
128
,
256
,
512
,
768
,
1024
],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
],
line_names
=
[
'Triton'
,
'Torch'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
8
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
8
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_16
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
16
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
16
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_32
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_prefill
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
128
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
if
__name__
==
'__main__'
:
# benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill
.
run
(
save_path
=
"./triton_test_prefill"
,
print_data
=
True
)
\ No newline at end of file
vllm/v1/attention/backends/mla/concatv
3Tritonfinal
.py
→
vllm/v1/attention/backends/mla/concatv
4_decode_only
.py
View file @
b1329ff2
...
@@ -243,4 +243,6 @@ if __name__ == '__main__':
...
@@ -243,4 +243,6 @@ if __name__ == '__main__':
benchmark
.
run
(
save_path
=
"./triton_test"
,
print_data
=
True
)
benchmark
.
run
(
save_path
=
"./triton_test"
,
print_data
=
True
)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
b1329ff2
...
@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...
@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm
import
envs
from
vllm
import
envs
from
vllm.v1.attention.backends.mla.concatv
3Tritonfinal
import
concat_helper
from
vllm.v1.attention.backends.mla.concatv
4_decode_only
import
concat_helper
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -167,8 +167,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -167,8 +167,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
envs
.
VLLM_USE_TRITON_CAT
:
if
envs
.
VLLM_USE_TRITON_CAT
:
q
=
concat_helper
(
q_nope
,
q_pe
,
dim
=-
1
)
\
if
q_nope
.
shape
[
0
]
<=
1024
:
.
unsqueeze
(
1
)
q
=
concat_helper
(
q_nope
,
q_pe
,
dim
=-
1
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment