Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d697dc01
Unverified
Commit
d697dc01
authored
Jan 11, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jan 11, 2025
Browse files
[Bugfix] Fix RobertaModel loading (#11940)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
a991f7d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
12 deletions
+67
-12
tests/model_executor/test_model_load_with_params.py
tests/model_executor/test_model_load_with_params.py
+26
-1
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+1
-0
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+40
-11
No files found.
tests/model_executor/test_model_load_with_params.py
View file @
d697dc01
...
...
@@ -2,7 +2,7 @@ import os
import
pytest
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.model_executor.layers.pooler
import
CLSPool
,
PoolingType
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
from
vllm.model_executor.models.roberta
import
RobertaEmbeddingModel
from
vllm.platforms
import
current_platform
...
...
@@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):
# assert output
assert
output
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
    """
    Load FacebookAI/roberta-base and check the resulting model:
    it must expose no ``lm_head``, be a ``RobertaEmbeddingModel``,
    and pool with ``CLSPool``.
    """
    model_name = "FacebookAI/roberta-base"
    runner_kwargs = dict(model_name=model_name,
                         dtype="float16",
                         max_model_len=MAX_MODEL_LEN)

    with vllm_runner(**runner_kwargs) as runner:
        prompt = ("Write a short story about a robot that"
                  " dreams for the first time.\n")
        output = runner.encode(prompt)

        # The tokenizer must have been resolved from the same model id.
        engine = runner.model.llm_engine
        assert engine.tokenizer.tokenizer_id == model_name

        # Reach through the executor to the model instance that was loaded.
        worker = engine.model_executor.driver_worker
        loaded_model = worker.model_runner.model
        assert not hasattr(loaded_model, "lm_head")
        assert isinstance(loaded_model, RobertaEmbeddingModel)
        assert isinstance(loaded_model._pooler, CLSPool)

        # The encode call must have produced a non-empty result.
        assert output
tests/models/embedding/language/test_embedding.py
View file @
d697dc01
...
...
@@ -25,6 +25,7 @@ from ..utils import check_embeddings_close
pytest
.
param
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
pytest
.
param
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
),
pytest
.
param
(
"Alibaba-NLP/gte-Qwen2-7B-instruct"
),
pytest
.
param
(
"sentence-transformers/stsb-roberta-base-v2"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
...
...
vllm/model_executor/models/roberta.py
View file @
d697dc01
import
itertools
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -20,6 +21,30 @@ from vllm.transformers_utils.config import (
from
.interfaces
import
SupportsCrossEncoding
def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """
    Split a stream of ``(name, tensor)`` pairs by the "roberta." prefix.

    Returns a pair of lazy iterables over the same input:

    * the first yields every entry whose name starts with "roberta.",
      with that prefix stripped from the name;
    * the second yields every other entry, name untouched.

    Both are backed by ``itertools.tee``, so they can be consumed
    independently of each other. Entries read through one of them are
    buffered by ``tee`` until the other one has read them as well.
    """
    prefix = "roberta."
    prefixed_src, other_src = itertools.tee(all_weights)

    # Entries carrying the prefix, renamed so the prefix is dropped.
    prefixed_weights = ((name[len(prefix):], tensor)
                        for name, tensor in prefixed_src
                        if name.startswith(prefix))

    def other_weights():
        # Everything that does not carry the prefix, passed through as-is.
        for name, tensor in other_src:
            if not name.startswith(prefix):
                yield name, tensor

    return prefixed_weights, other_weights()
class
RobertaEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
RobertaConfig
):
...
...
@@ -152,6 +177,18 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
prefix
=
prefix
,
embedding_class
=
RobertaEmbedding
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """
    Load checkpoint weights into ``self.model``.

    The weights are first passed through ``self.hf_to_vllm_mapper`` and
    then split on the "roberta." prefix. The prefixed group (prefix
    stripped) is tried first; if nothing was loaded from it, the
    unprefixed group is tried instead. Fails with an assertion if
    neither attempt loaded anything.
    """
    mapped = self.hf_to_vllm_mapper.apply(weights)

    # Lazy split: "roberta."-prefixed entries vs. everything else.
    # The prefixed form is what models like FacebookAI/roberta-base ship.
    prefixed, unprefixed = roberta_task_weights_filter(mapped)

    loaded = self.model.load_weights(prefixed)
    if len(loaded) == 0:
        # Models like `sentence-transformers/stsb-roberta-base-v2` share
        # the architecture but store their weights without the "roberta"
        # prefix, so fall back to the unprefixed group.
        loaded = self.model.load_weights(unprefixed)

    assert len(loaded), "Unable to load RobertaEmbeddingModel"
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
"""A model that uses Roberta to provide embedding functionalities.
...
...
@@ -181,20 +218,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
self_weights
=
[]
def
weight_filter
():
for
name
,
weight
in
weights
:
if
name
.
startswith
(
"roberta."
):
yield
(
name
[
len
(
"roberta."
):],
weight
)
else
:
self_weights
.
append
((
name
,
weight
))
self
.
roberta
.
load_weights
(
weight_filter
())
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
self
.
roberta
.
load_weights
(
bert_weights
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
self
_weights
:
for
name
,
loaded_weight
in
task
_weights
:
if
name
.
startswith
(
"classifier"
):
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment