"vllm/vscode:/vscode.git/clone" did not exist on "d6953beb91da4e9c99be4c0a1304a2d24189535c"
Unverified Commit 9b4b1503 authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Ignore `lm_head` when loading embedding models (#10719)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 197b4484
...@@ -443,6 +443,8 @@ class BertEmbeddingModel(nn.Module): ...@@ -443,6 +443,8 @@ class BertEmbeddingModel(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights) weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights) self.model.load_weights(weights)
def _build_model(self, def _build_model(self,
......
...@@ -504,4 +504,6 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): ...@@ -504,4 +504,6 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights) weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights) self.model.load_weights(weights)
...@@ -689,6 +689,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -689,6 +689,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights) weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights) self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None: def load_kv_cache_scales(self, quantization_param_path: str) -> None:
......
...@@ -580,4 +580,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -580,4 +580,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights) weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights) self.model.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment