Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5bff999d
Unverified
Commit
5bff999d
authored
Feb 15, 2026
by
bnellnm
Committed by
GitHub
Feb 15, 2026
Browse files
[Bugfix] Add method to swap quant_method on FusedMoE to fix LoRA issues (#34453)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
bb85929a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
12 deletions
+21
-12
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+3
-2
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+18
-10
No files found.
vllm/lora/layers/fused_moe.py
View file @
5bff999d
...
...
@@ -338,8 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts
.
moe_sum
=
moe_sum_decorator
(
self
.
base_layer
,
fused_experts
.
moe_sum
)
self
.
base_layer
.
quant_method
=
FusedMoEModularMethod
(
self
.
base_layer
.
quant_method
,
m_fused_moe_fn
# TODO(bnell): find a less intrusive way to handle this.
self
.
base_layer
.
_replace_quant_method
(
FusedMoEModularMethod
(
self
.
base_layer
.
quant_method
,
m_fused_moe_fn
)
)
def
_create_lora_a_weights
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
5bff999d
...
...
@@ -655,6 +655,16 @@ class FusedMoE(CustomOp):
enable_dbo
=
self
.
vllm_config
.
parallel_config
.
enable_dbo
,
)
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
# can safely swap out the quant_method. We should figure out a less
# intrusive way to do this.
def
_replace_quant_method
(
self
,
mk
:
FusedMoEMethodBase
):
self
.
quant_method
=
mk
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self
.
runner
=
self
.
_init_runner
()
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
...
...
@@ -676,17 +686,15 @@ class FusedMoE(CustomOp):
logger
.
debug
(
"%s for %s(%s)"
,
prepare_finalize
.
__class__
.
__name__
,
self
,
id
(
self
)
)
self
.
quant_method
=
FusedMoEModularMethod
.
make
(
self
,
self
.
quant_method
,
prepare_finalize
,
self
.
shared_experts
,
inplace
=
not
self
.
moe_config
.
disable_inplace
,
self
.
_replace_quant_method
(
FusedMoEModularMethod
.
make
(
self
,
self
.
quant_method
,
prepare_finalize
,
self
.
shared_experts
,
inplace
=
not
self
.
moe_config
.
disable_inplace
,
)
)
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self
.
runner
=
self
.
_init_runner
()
@
property
def
shared_experts
(
self
)
->
torch
.
nn
.
Module
|
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment