Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b4b1503
Unverified
Commit
9b4b1503
authored
Nov 28, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 27, 2024
Browse files
[Bugfix] Ignore `lm_head` when loading embedding models (#10719)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
197b4484
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
0 deletions
+8
-0
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+2
-0
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+2
-0
No files found.
vllm/model_executor/models/bert.py
View file @
9b4b1503
...
...
@@ -443,6 +443,8 @@ class BertEmbeddingModel(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
def
_build_model
(
self
,
...
...
vllm/model_executor/models/gemma2.py
View file @
9b4b1503
...
...
@@ -504,4 +504,6 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
vllm/model_executor/models/llama.py
View file @
9b4b1503
...
...
@@ -689,6 +689,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
def
load_kv_cache_scales
(
self
,
quantization_param_path
:
str
)
->
None
:
...
...
vllm/model_executor/models/qwen2.py
View file @
9b4b1503
...
...
@@ -580,4 +580,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment