Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5fe03549
Commit
5fe03549
authored
Feb 04, 2026
by
zhuwenwen
Browse files
[perf] add VLLM_USE_FUSED_FILL_RMS_CAT to use lightop for dpsk mtp fill + rms*2 + cat
parent
b8c7ba0a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
6 deletions
+21
-6
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+4
-0
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+12
-6
No files found.
vllm/envs.py
View file @
5fe03549
...
@@ -288,6 +288,7 @@ if TYPE_CHECKING:
...
@@ -288,6 +288,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
bool
=
False
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
bool
=
False
VLLM_USE_TOPK_RENORM
:
bool
=
False
VLLM_USE_TOPK_RENORM
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_W8A8_BACKEND
:
int
=
3
VLLM_W8A8_BACKEND
:
int
=
3
...
@@ -1819,6 +1820,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1819,6 +1820,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE"
:
"VLLM_USE_FUSED_RMS_ROPE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_RMS_ROPE"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_RMS_ROPE"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# W8A8 GEMM backend selection for vLLM quantized models.
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# cutlass: 2 (will remove in the future)
...
...
vllm/model_executor/model_loader/utils.py
View file @
5fe03549
...
@@ -192,6 +192,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
...
@@ -192,6 +192,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os
.
environ
[
'VLLM_USE_LIGHTOP'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_OPT_CAT"
):
if
not
envs
.
is_set
(
"VLLM_USE_OPT_CAT"
):
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
):
os
.
environ
[
'VLLM_USE_FUSED_FILL_RMS_CAT'
]
=
'1'
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if not envs.is_set("USE_FUSED_RMS_QUANT"):
# if not envs.is_set("USE_FUSED_RMS_QUANT"):
# os.environ['USE_FUSED_RMS_QUANT'] = '1'
# os.environ['USE_FUSED_RMS_QUANT'] = '1'
...
@@ -224,6 +226,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
...
@@ -224,6 +226,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os
.
environ
[
'VLLM_USE_LIGHTOP'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_OPT_CAT"
):
if
not
envs
.
is_set
(
"VLLM_USE_OPT_CAT"
):
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
):
os
.
environ
[
'VLLM_USE_FUSED_FILL_RMS_CAT'
]
=
'1'
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if not envs.is_set("USE_FUSED_RMS_QUANT"):
# if not envs.is_set("USE_FUSED_RMS_QUANT"):
# os.environ['USE_FUSED_RMS_QUANT'] = '1'
# os.environ['USE_FUSED_RMS_QUANT'] = '1'
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
5fe03549
...
@@ -39,6 +39,7 @@ from .deepseek_v2 import (
...
@@ -39,6 +39,7 @@ from .deepseek_v2 import (
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -109,13 +110,18 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
...
@@ -109,13 +110,18 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
inputs_embeds
is
not
None
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
# masking inputs at position 0, as not needed by MTP
inputs_embeds
=
torch
.
where
(
positions
.
unsqueeze
(
-
1
)
==
0
,
0
,
inputs_embeds
)
if
envs
.
VLLM_USE_FUSED_FILL_RMS_CAT
:
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
hidden_states_fuse
=
torch
.
empty
(
inputs_embeds
.
shape
[
0
],
inputs_embeds
.
shape
[
1
]
*
2
,
device
=
inputs_embeds
.
device
,
dtype
=
inputs_embeds
.
dtype
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
torch
.
ops
.
vllm
.
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
self
.
enorm
.
weight
,
self
.
hnorm
.
weight
,
self
.
enorm
.
variance_epsilon
)
hidden_states
=
self
.
eh_proj
(
hidden_states_fuse
)
else
:
inputs_embeds
=
torch
.
where
(
positions
.
unsqueeze
(
-
1
)
==
0
,
0
,
inputs_embeds
)
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
)
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
)
)
)
hidden_states
,
residual
=
self
.
mtp_block
(
hidden_states
,
residual
=
self
.
mtp_block
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
None
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment