Unverified Commit ba4f8267 authored by Woosuk Kwon, committed by GitHub

[BugFix] Fix weight loading for Mixtral with TP (#2208)

parent de60a3fb
@@ -49,7 +49,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
@@ -147,7 +122,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -427,6 +402,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
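For context, the pattern this change settles on can be illustrated with a minimal, self-contained sketch. The TinyExpert, TinyMoE, and load_weights names below are hypothetical stand-ins, not vLLM's actual MixtralMLP/MixtralMoE classes: under tensor parallelism each rank instantiates only the experts it owns, leaves the rest as None, and the checkpoint loader skips any expert tensor whose name is absent from that rank's parameter dict, mirroring the "block_sparse_moe.experts." guard added above.

import torch
import torch.nn as nn


class TinyExpert(nn.Module):
    """Hypothetical stand-in for an expert MLP (e.g. MixtralMLP)."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w1(x)


class TinyMoE(nn.Module):
    """Hypothetical MoE layer: each rank owns a disjoint subset of experts."""

    def __init__(self, num_experts: int, hidden_size: int,
                 expert_indices: set) -> None:
        super().__init__()
        # Experts not assigned to this rank are left as None, so they
        # contribute no entries to this rank's parameter dict.
        self.experts = nn.ModuleList([
            TinyExpert(hidden_size) if i in expert_indices else None
            for i in range(num_experts)
        ])


def load_weights(model: nn.Module, checkpoint: dict) -> None:
    """Copy checkpoint tensors, skipping experts this rank does not own."""
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in checkpoint.items():
        # Mirrors the added check: expert weights missing from this rank's
        # parameter dict belong to another rank and are simply ignored.
        if "experts." in name and name not in params_dict:
            continue
        params_dict[name].data.copy_(loaded_weight)


if __name__ == "__main__":
    # This rank owns experts {0, 1} of 4; the checkpoint holds all 4.
    model = TinyMoE(num_experts=4, hidden_size=8, expert_indices={0, 1})
    full_ckpt = {f"experts.{i}.w1.weight": torch.randn(8, 8) for i in range(4)}
    load_weights(model, full_ckpt)
    print(sorted(name for name, _ in model.named_parameters()))
    # -> ['experts.0.w1.weight', 'experts.1.w1.weight']

Compared to routing other ranks' experts into placeholder modules with no-op weight loaders, leaving them as None keeps those parameter names out of params_dict entirely, so the name check is the only logic needed.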