Unverified commit 5bff999d, authored by bnellnm, committed by GitHub
Browse files

[Bugfix] Add method to swap quant_method on FusedMoE to fix LoRA issues (#34453)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent bb85929a
...@@ -338,8 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -338,8 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts.moe_sum = moe_sum_decorator( fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum self.base_layer, fused_experts.moe_sum
) )
self.base_layer.quant_method = FusedMoEModularMethod( # TODO(bnell): find a less intrusive way to handle this.
self.base_layer.quant_method, m_fused_moe_fn self.base_layer._replace_quant_method(
FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn)
) )
def _create_lora_a_weights( def _create_lora_a_weights(
......
...@@ -655,6 +655,16 @@ class FusedMoE(CustomOp): ...@@ -655,6 +655,16 @@ class FusedMoE(CustomOp):
enable_dbo=self.vllm_config.parallel_config.enable_dbo, enable_dbo=self.vllm_config.parallel_config.enable_dbo,
) )
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
# can safely swap out the quant_method. We should figure out a less
# intrusive way to do this.
def _replace_quant_method(self, mk: FusedMoEMethodBase):
# Install `mk` as this layer's quant_method and rebuild the runner.
#
# Args:
#   mk: the replacement quant method object; it is stored on
#       self.quant_method as-is.
#
# Returns nothing; the effect is the two attribute writes below
# (self.quant_method and self.runner).
#
# NOTE(review): the assignment is done before _init_runner() is called,
# presumably because the runner picks up self.quant_method when it is
# constructed -- confirm against _init_runner before reordering.
self.quant_method = mk
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self.runner = self._init_runner()
# Note: maybe_init_modular_kernel should only be called by # Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model. # prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
...@@ -676,17 +686,15 @@ class FusedMoE(CustomOp): ...@@ -676,17 +686,15 @@ class FusedMoE(CustomOp):
logger.debug( logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
) )
self.quant_method = FusedMoEModularMethod.make( self._replace_quant_method(
FusedMoEModularMethod.make(
self, self,
self.quant_method, self.quant_method,
prepare_finalize, prepare_finalize,
self.shared_experts, self.shared_experts,
inplace=not self.moe_config.disable_inplace, inplace=not self.moe_config.disable_inplace,
) )
# We need to force reconstruction of runner because we're swapping out )
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self.runner = self._init_runner()
@property @property
def shared_experts(self) -> torch.nn.Module | None: def shared_experts(self) -> torch.nn.Module | None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment