Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e80dcabe
Commit
e80dcabe
authored
Dec 23, 2025
by
zhuwenwen
Browse files
add VLLM_USE_FUSED_FILL_RMS_CAT for dpsk mtp fill + rms*2 + cat
update VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT impl
parent
4f9947e6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
32 additions
and
36 deletions
+32
-36
vllm/attention/layer.py
vllm/attention/layer.py
+13
-24
vllm/envs.py
vllm/envs.py
+5
-1
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+0
-4
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+13
-6
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
No files found.
vllm/attention/layer.py
View file @
e80dcabe
...
...
@@ -553,31 +553,20 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
else
:
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
return
direct_register_custom_op
(
...
...
vllm/envs.py
View file @
e80dcabe
...
...
@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1275,7 +1276,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MARLIN_W16A16_MOE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/model_loader/utils.py
View file @
e80dcabe
...
...
@@ -253,8 +253,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
...
@@ -298,8 +296,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
e80dcabe
...
...
@@ -28,6 +28,9 @@ from .interfaces import SupportsPP
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
import
vllm.envs
as
envs
from
lightop
import
fuse_fill_rms_x2_concat
class
SharedHead
(
nn
.
Module
):
...
...
@@ -84,12 +87,16 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
if
envs
.
VLLM_USE_FUSED_FILL_RMS_CAT
:
hidden_states_fuse
=
torch
.
empty
(
hidden_states
.
shape
[
0
],
hidden_states
.
shaope
[
1
]
*
2
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
self
.
enorm
.
weight
,
self
.
hnorm
.
weight
,
self
.
enorm
.
variance_epsilon
)
hidden_states
=
self
.
eh_proj
(
hidden_states_fuse
)
else
:
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
hidden_states
,
residual
=
self
.
mtp_block
(
positions
=
positions
,
hidden_states
=
hidden_states
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
e80dcabe
...
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
from
lightop
import
fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment