Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7b5575fa
Unverified
Commit
7b5575fa
authored
Dec 05, 2025
by
Wentao Ye
Committed by
GitHub
Dec 05, 2025
Browse files
[Bug] Fix vLLM config is not set error (#29999)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
77e44728
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
27 deletions
+47
-27
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+2
-0
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+6
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+30
-27
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+3
-0
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+6
-0
No files found.
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
7b5575fa
...
@@ -460,6 +460,7 @@ def cutlass_moe_fp8(
...
@@ -460,6 +460,7 @@ def cutlass_moe_fp8(
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
parallel_config
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...
@@ -537,6 +538,7 @@ def cutlass_moe_fp8(
...
@@ -537,6 +538,7 @@ def cutlass_moe_fp8(
c_strides2
=
c_strides2
,
c_strides2
=
c_strides2
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
),
),
parallel_config
=
parallel_config
,
)
)
return
fn
(
return
fn
(
...
...
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
7b5575fa
...
@@ -44,6 +44,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
...
@@ -44,6 +44,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
shared_experts
:
torch
.
nn
.
Module
|
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
,
)
->
"FusedMoEModularMethod"
:
)
->
"FusedMoEModularMethod"
:
parallel_config
=
getattr
(
getattr
(
moe_layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
,
)
return
FusedMoEModularMethod
(
return
FusedMoEModularMethod
(
old_quant_method
,
old_quant_method
,
FusedMoEModularKernel
(
FusedMoEModularKernel
(
...
@@ -51,6 +56,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
...
@@ -51,6 +56,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
shared_experts
,
shared_experts
,
getattr
(
moe_layer
,
"shared_experts_stream"
,
None
),
getattr
(
moe_layer
,
"shared_experts_stream"
,
None
),
parallel_config
=
parallel_config
,
),
),
)
)
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
7b5575fa
...
@@ -10,7 +10,7 @@ from typing import final
...
@@ -10,7 +10,7 @@ from typing import final
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
ParallelConfig
,
get_current_vllm_config
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
...
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts
:
FusedMoEPermuteExpertsUnpermute
,
fused_experts
:
FusedMoEPermuteExpertsUnpermute
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
shared_experts_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
parallel_config
:
ParallelConfig
|
None
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
self
.
prepare_finalize
=
prepare_finalize
...
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
shared_experts
=
shared_experts
self
.
shared_experts
=
shared_experts
self
.
shared_experts_stream
=
shared_experts_stream
self
.
shared_experts_stream
=
shared_experts_stream
# cache whether this worker is using DP+EP
if
parallel_config
is
None
:
parallel_config
=
get_current_vllm_config
().
parallel_config
self
.
is_dp_ep
=
(
parallel_config
.
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
)
self
.
_post_init_setup
()
self
.
_post_init_setup
()
assert
(
assert
(
prepare_finalize
.
activation_format
==
fused_experts
.
activation_formats
[
0
]
prepare_finalize
.
activation_format
==
fused_experts
.
activation_formats
[
0
]
...
@@ -811,13 +820,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -811,13 +820,7 @@ class FusedMoEModularKernel(torch.nn.Module):
is_forward_context_available
()
is_forward_context_available
()
and
get_forward_context
().
attn_metadata
is
None
and
get_forward_context
().
attn_metadata
is
None
)
)
if
is_profile_run
and
self
.
fused_experts
.
supports_chunking
():
if
is_profile_run
and
self
.
fused_experts
.
supports_chunking
()
and
self
.
is_dp_ep
:
parallel_config
=
get_current_vllm_config
().
parallel_config
is_dp_ep
=
(
parallel_config
.
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
)
if
is_dp_ep
:
max_workspace_13
,
max_workspace_2
,
max_fused_out_shape
=
(
max_workspace_13
,
max_workspace_2
,
max_fused_out_shape
=
(
self
.
fused_experts
.
workspace_shapes
(
self
.
fused_experts
.
workspace_shapes
(
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
,
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
7b5575fa
...
@@ -1287,6 +1287,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1287,6 +1287,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
ab_strides2
=
self
.
ab_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
parallel_config
=
getattr
(
getattr
(
layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
),
)
)
else
:
else
:
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
7b5575fa
...
@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
...
@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
assert
quant_config
is
not
None
assert
quant_config
is
not
None
# Construct modular kernel with block-scale support when requested.
# Construct modular kernel with block-scale support when requested.
parallel_config
=
getattr
(
getattr
(
layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
,
)
fused_experts
=
mk
.
FusedMoEModularKernel
(
fused_experts
=
mk
.
FusedMoEModularKernel
(
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
moe
=
moe
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
moe
=
moe
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
...
@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8(
...
@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8(
out_dtype
=
hidden_states
.
dtype
,
out_dtype
=
hidden_states
.
dtype
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
),
),
parallel_config
=
parallel_config
,
)
)
return
fused_experts
(
return
fused_experts
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment