Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0dfa004
Commit
b0dfa004
authored
Sep 10, 2025
by
zhuwenwen
Browse files
add VLLM_USE_TRITON_CAT to opt torch cat
parent
a99300bd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
191 additions
and
5 deletions
+191
-5
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+11
-3
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
+166
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+8
-2
No files found.
vllm/envs.py
View file @
b0dfa004
...
...
@@ -195,6 +195,7 @@ if TYPE_CHECKING:
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHT_OP
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1342,6 +1343,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHT_OP"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
"VLLM_USE_TRITON_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
b0dfa004
...
...
@@ -224,6 +224,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.attention.backends.mla.concatv3Tritonfinal
import
concat_helper
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -1217,8 +1218,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
envs
.
VLLM_USE_TRITON_CAT
:
k
=
concat_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
self
.
_run_prefill_context_chunk
(
prefill
=
prefill_metadata
,
...
...
@@ -1267,7 +1272,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
envs
.
VLLM_USE_TRITON_CAT
:
k
=
concat_helper
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_run_prefill_new_tokens
(
prefill
=
attn_metadata
.
prefill
,
...
...
vllm/v1/attention/backends/mla/concatv3Tritonfinal.py
0 → 100644
View file @
b0dfa004
import
triton
import
triton.language
as
tl
import
torch
from
functools
import
reduce
import
pytest
import
torch
import
math
@
pytest
.
mark
.
parametrize
(
"shape_pair,dim"
,
[
(((
4
,
8
,
512
),
(
4
,
8
,
64
)),
2
),
(((
8
,
8
,
512
),
(
8
,
8
,
64
)),
2
),
(((
16
,
8
,
512
),
(
16
,
8
,
64
)),
2
),
(((
32
,
8
,
512
),
(
32
,
8
,
64
)),
2
),
(((
64
,
8
,
512
),
(
64
,
8
,
64
)),
2
),
(((
128
,
8
,
512
),
(
128
,
8
,
64
)),
2
),
(((
256
,
8
,
512
),
(
256
,
8
,
64
)),
2
),
(((
512
,
8
,
512
),
(
512
,
8
,
64
)),
2
),
(((
672
,
8
,
512
),
(
672
,
8
,
64
)),
2
),
(((
768
,
8
,
512
),
(
768
,
8
,
64
)),
2
),
(((
896
,
8
,
512
),
(
896
,
8
,
64
)),
2
),
(((
1024
,
8
,
512
),
(
1024
,
8
,
64
)),
2
),
(((
4
,
16
,
512
),
(
4
,
16
,
64
)),
2
),
(((
8
,
16
,
512
),
(
8
,
16
,
64
)),
2
),
(((
16
,
16
,
512
),
(
16
,
16
,
64
)),
2
),
(((
32
,
16
,
512
),
(
32
,
16
,
64
)),
2
),
(((
64
,
16
,
512
),
(
64
,
16
,
64
)),
2
),
(((
128
,
16
,
512
),
(
128
,
16
,
64
)),
2
),
(((
256
,
16
,
512
),
(
256
,
16
,
64
)),
2
),
(((
512
,
16
,
512
),
(
512
,
16
,
64
)),
2
),
(((
672
,
16
,
512
),
(
672
,
16
,
64
)),
2
),
(((
768
,
16
,
512
),
(
768
,
16
,
64
)),
2
),
(((
896
,
16
,
512
),
(
896
,
16
,
64
)),
2
),
(((
1024
,
16
,
512
),
(
1024
,
16
,
64
)),
2
),
(((
4
,
32
,
512
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
512
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
512
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
512
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
512
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
512
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
512
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
512
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
512
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
512
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
512
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
512
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
x
=
torch
.
randn
(
*
shape1
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
(
*
shape2
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
for
sub_section_index
in
range
(
Per_block
):
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
if
sub_offset
<=
section_num
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
sub_offset
*
A_section_numel
B_ptr_block_start
=
B_ptr
+
sub_offset
*
B_section_numel
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
A_section_numel
val_from_A
=
tl
.
load
(
A_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
offset_idx
,
val_from_A
,
mask
=
mask
)
for
offset
in
range
(
0
,
B_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
B_section_numel
val_from_B
=
tl
.
load
(
B_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
A
=
A
.
contiguous
()
B
=
B
.
contiguous
()
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
Per_block
=
1
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
0
]
>
64
):
Per_block
=
2
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'size'
],
x_vals
=
[
4
,
8
,
16
,
32
,
64
,
96
,
128
,
256
,
512
,
768
,
1024
],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
],
line_names
=
[
'Triton'
,
'Torch'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
8
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
8
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_16
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
16
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
16
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_32
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
if
__name__
==
'__main__'
:
benchmark
.
run
(
save_path
=
"./triton_test_8"
,
print_data
=
True
)
benchmark_16
.
run
(
save_path
=
"./triton_test_16"
,
print_data
=
True
)
benchmark_32
.
run
(
save_path
=
"./triton_test_32"
,
print_data
=
True
)
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
b0dfa004
...
...
@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm
import
envs
from
vllm.v1.attention.backends.mla.concatv3Tritonfinal
import
concat_helper
logger
=
init_logger
(
__name__
)
...
...
@@ -176,8 +178,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
envs
.
VLLM_USE_TRITON_CAT
:
q
=
concat_helper
(
q_nope
,
q_pe
,
dim
=-
1
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment