Unverified commit 5bff999d, authored by bnellnm, committed by GitHub
Browse files

[Bugfix] Add method to swap quant_method on FusedMoE to fix LoRA issues (#34453)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent bb85929a
......@@ -338,8 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
self.base_layer.quant_method = FusedMoEModularMethod(
self.base_layer.quant_method, m_fused_moe_fn
# TODO(bnell): find a less intrusive way to handle this.
self.base_layer._replace_quant_method(
FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn)
)
def _create_lora_a_weights(
......
......@@ -655,6 +655,16 @@ class FusedMoE(CustomOp):
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
)
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
# can safely swap out the quant_method. We should figure out a less
# intrusive way to do this.
#
# Installs `mk` as this layer's quant_method and then rebuilds the runner by
# calling `self._init_runner()`. Using this hook instead of assigning
# `self.quant_method` directly keeps the two steps together, so the runner
# is always reconstructed after a swap.
def _replace_quant_method(self, mk: FusedMoEMethodBase):
# Install the replacement first; the runner is rebuilt only afterwards.
self.quant_method = mk
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
# NOTE(review): presumably _init_runner() reads self.quant_method, which
# would make the statement order here significant -- confirm against
# _init_runner before reordering.
self.runner = self._init_runner()
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
......@@ -676,17 +686,15 @@ class FusedMoE(CustomOp):
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
)
self.quant_method = FusedMoEModularMethod.make(
self,
self.quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
self._replace_quant_method(
FusedMoEModularMethod.make(
self,
self.quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
)
)
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self.runner = self._init_runner()
@property
def shared_experts(self) -> torch.nn.Module | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment