Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e4b87133
Unverified
Commit
e4b87133
authored
May 11, 2025
by
wang.yuqi
Committed by
GitHub
May 11, 2025
Browse files
[New Model]: nomic-embed-text-v2-moe (#17785)
parent
06c0922a
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
899 additions
and
364 deletions
+899
-364
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+17
-3
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+111
-0
tests/models/language/pooling/test_nomic.py
tests/models/language/pooling/test_nomic.py
+47
-0
tests/models/language/pooling/test_snowflake_arctic_embed.py
tests/models/language/pooling/test_snowflake_arctic_embed.py
+22
-43
tests/models/utils.py
tests/models/utils.py
+2
-1
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+29
-238
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+652
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-2
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+17
-77
No files found.
docs/source/models/supported_models.md
View file @
e4b87133
...
@@ -622,7 +622,7 @@ Specified using `--task embed`.
...
@@ -622,7 +622,7 @@ Specified using `--task embed`.
*
[
PP
](
#distributed-serving
)
*
[
PP
](
#distributed-serving
)
-
*
`BertModel`
-
*
`BertModel`
*
BERT-based
*
BERT-based
*
`BAAI/bge-base-en-v1.5`
, etc.
*
`BAAI/bge-base-en-v1.5`
,
`Snowflake/snowflake-arctic-embed-xs`
,
etc.
*
*
*
*
-
*
`Gemma2Model`
-
*
`Gemma2Model`
...
@@ -635,6 +635,16 @@ Specified using `--task embed`.
...
@@ -635,6 +635,16 @@ Specified using `--task embed`.
*
`parasail-ai/GritLM-7B-vllm`
.
*
`parasail-ai/GritLM-7B-vllm`
.
*
✅︎
*
✅︎
*
✅︎
*
✅︎
-
*
`GteModel`
*
GteModel
*
`Snowflake/snowflake-arctic-embed-m-v2.0`
.
*
*
︎
-
*
`NomicBertModel`
*
NomicBertModel
*
`nomic-ai/nomic-embed-text-v1`
,
`nomic-ai/nomic-embed-text-v2-moe`
,
`Snowflake/snowflake-arctic-embed-m-long`
, etc.
*
︎
*
︎
-
*
`LlamaModel`
,
`LlamaForCausalLM`
,
`MistralModel`
, etc.
-
*
`LlamaModel`
,
`LlamaForCausalLM`
,
`MistralModel`
, etc.
*
Llama-based
*
Llama-based
*
`intfloat/e5-mistral-7b-instruct`
, etc.
*
`intfloat/e5-mistral-7b-instruct`
, etc.
...
@@ -647,12 +657,12 @@ Specified using `--task embed`.
...
@@ -647,12 +657,12 @@ Specified using `--task embed`.
*
✅︎
*
✅︎
-
*
`RobertaModel`
,
`RobertaForMaskedLM`
-
*
`RobertaModel`
,
`RobertaForMaskedLM`
*
RoBERTa-based
*
RoBERTa-based
*
`sentence-transformers/all-roberta-large-v1`
,
`sentence-transformers/all-roberta-large-v1`
,
etc.
*
`sentence-transformers/all-roberta-large-v1`
, etc.
*
*
*
*
-
*
`XLMRobertaModel`
-
*
`XLMRobertaModel`
*
XLM-RoBERTa-based
*
XLM-RoBERTa-based
*
`intfloat/multilingual-e5-large`
,
`jinaai/jina-reranker-v2-base-multilingual`
, etc.
*
`intfloat/multilingual-e5-large`
,
`jinaai/jina-reranker-v2-base-multilingual`
,
`Snowflake/snowflake-arctic-embed-l-v2.0`
,
`jinaai/jina-embeddings-v3`
(see note),
etc.
*
*
*
*
:::
:::
...
@@ -670,6 +680,10 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code`
...
@@ -670,6 +680,10 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code`
See
[
relevant issue on HF Transformers
](
https://github.com/huggingface/transformers/issues/34882
)
.
See
[
relevant issue on HF Transformers
](
https://github.com/huggingface/transformers/issues/34882
)
.
:::
:::
:::{note}
`jinaai/jina-embeddings-v3`
supports multiple tasks through LoRA; for now, vLLM supports only the text-matching task, which it handles by merging the LoRA weights.
:::
If your model is not in the above list, we will try to automatically convert the model using
If your model is not in the above list, we will try to automatically convert the model using
{func}
`~vllm.model_executor.models.adapters.as_embedding_model`
. By default, the embeddings
{func}
`~vllm.model_executor.models.adapters.as_embedding_model`
. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
...
...
tests/models/language/pooling/mteb_utils.py
0 → 100644
View file @
e4b87133
# SPDX-License-Identifier: Apache-2.0
import
math
from
collections.abc
import
Sequence
import
mteb
import
numpy
as
np
import
pytest
from
tests.models.utils
import
EmbedModelInfo
# Observations for most models on the STS12 task (see #17175):
# - a different model implementation, or a minor change in tensor dtype,
#   shifts the score by less than 1e-4;
# - a genuinely different model shifts the score by more than 1e-3.
# 1e-4 is therefore a good tolerance threshold.
MTEB_EMBED_TOL = 1e-4
MTEB_EMBED_TASKS = ["STS12"]
class VllmMtebEncoder(mteb.Encoder):
    """MTEB encoder that delegates embedding to a vLLM model wrapper.

    Sentences are embedded in a shuffled order and the embeddings are put
    back into the caller's order before being returned.
    """

    def __init__(self, vllm_model):
        super().__init__()
        self.model = vllm_model
        # Fixed seed: the shuffles are the same on every run.
        self.rng = np.random.default_rng(seed=42)

    def encode(
        self,
        sentences: Sequence[str],
        *args,
        **kwargs,
    ) -> np.ndarray:
        """Embed ``sentences`` and return one row per input sentence.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Randomizing the order is meant to surface potential scheduling
        # issues.
        order = self.rng.permutation(len(sentences))
        shuffled = [sentences[idx] for idx in order]

        raw_outputs = self.model.encode(shuffled, use_tqdm=False)

        # argsort of a permutation is its inverse, so this indexing restores
        # the original sentence order.
        restore = np.argsort(order)
        return np.array(raw_outputs)[restore]
class OpenAIClientMtebEncoder(mteb.Encoder):
    """MTEB encoder that obtains embeddings via ``client.embeddings.create``.

    Sentences are sent in a shuffled order and the embeddings are put back
    into the caller's order before being returned.
    """

    def __init__(self, model_name: str, client):
        super().__init__()
        self.model_name = model_name
        self.client = client
        # Fixed seed: the shuffles are the same on every run.
        self.rng = np.random.default_rng(seed=42)

    def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
        """Embed ``sentences`` and return one row per input sentence.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Randomizing the order is meant to surface potential scheduling
        # issues.
        order = self.rng.permutation(len(sentences))
        shuffled = [sentences[idx] for idx in order]

        response = self.client.embeddings.create(model=self.model_name,
                                                 input=shuffled)
        vectors = [item.embedding for item in response.data]

        # argsort of a permutation is its inverse, so this indexing restores
        # the original sentence order.
        restore = np.argsort(order)
        return np.array(vectors)[restore]
def run_mteb_embed_task(encoder, tasks):
    """Run the named MTEB tasks with ``encoder`` and return one main score.

    Only the first result is inspected: the value returned is the
    ``"main_score"`` entry of the first item under its ``"test"`` scores.
    """
    resolved_tasks = mteb.get_tasks(tasks=tasks)
    runner = mteb.MTEB(tasks=resolved_tasks)
    task_results = runner.run(encoder, verbosity=0, output_folder=None)

    first_result = task_results[0]
    test_scores = first_result.scores["test"]
    return test_scores[0]["main_score"]
def run_mteb_embed_task_st(model_name, tasks):
    """Score ``model_name`` on ``tasks`` using a ``SentenceTransformer``.

    The model is built from ``model_name`` and passed to
    ``run_mteb_embed_task`` as the encoder.
    """
    # Imported inside the function, as in the original, so that importing
    # this module does not import sentence_transformers.
    from sentence_transformers import SentenceTransformer

    return run_mteb_embed_task(SentenceTransformer(model_name), tasks)
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
    """Check that vLLM and SentenceTransformers agree on the MTEB main score.

    The model is first run through vLLM and then through the HF runner in
    sentence-transformer mode; both are scored on ``MTEB_EMBED_TASKS`` and
    the two main scores must differ by at most ``MTEB_EMBED_TOL``.

    Args:
        hf_runner: Context-manager factory for the reference HF model.
        vllm_runner: Context-manager factory for the vLLM model.
        model_info: Model name, dtype, expected architecture (optional) and
            the ``enable_test`` switch.
    """
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
        pytest.skip("Skipping test.")

    with vllm_runner(model_info.name,
                     task="embed",
                     max_model_len=None,
                     dtype=model_info.dtype) as vllm_model:
        model_config = vllm_model.model.llm_engine.model_config

        if model_info.architecture:
            assert model_info.architecture in model_config.architectures

        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)
        vllm_dtype = model_config.dtype
        # Run the reference with the checkpoint's own dtype when the HF
        # config declares one; otherwise reuse the dtype vLLM ran with.
        model_dtype = getattr(model_config.hf_config, "torch_dtype",
                              vllm_dtype)

    with hf_runner(model_info.name,
                   is_sentence_transformer=True,
                   dtype=model_dtype) as hf_model:
        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", vllm_dtype, vllm_main_score)
    print("SentenceTransformer:", model_dtype, st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    # MTEB_EMBED_TOL is documented as a bound on the score *difference*,
    # so compare with an absolute tolerance. With rel_tol the effective
    # bound would scale with the score itself.
    assert math.isclose(st_main_score,
                        vllm_main_score,
                        rel_tol=0.0,
                        abs_tol=MTEB_EMBED_TOL), (
                            f"MTEB main score mismatch for {model_info.name}: "
                            f"SentenceTransformer={st_main_score}, "
                            f"vLLM={vllm_main_score}")
tests/models/language/pooling/test_nomic.py
0 → 100644
View file @
e4b87133
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
...utils
import
EmbedModelInfo
,
run_embedding_correctness_test
# Every entry shares the NomicBertModel architecture and float32 dtype;
# only the checkpoint name and whether it is actually tested differ.
MODELS = [
    EmbedModelInfo(checkpoint,
                   architecture="NomicBertModel",
                   dtype="float32",
                   enable_test=enabled)
    for checkpoint, enabled in (
        ("nomic-ai/nomic-embed-text-v1", True),
        ("nomic-ai/nomic-embed-text-v1.5", False),
        ("nomic-ai/nomic-embed-text-v2-moe", True),
    )
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(
    hf_runner,
    vllm_runner,
    model_info: EmbedModelInfo,
) -> None:
    """Run the shared MTEB score comparison for each entry in ``MODELS``."""
    # Imported inside the test, as in the original, so that collecting this
    # module does not import mteb_utils.
    from . import mteb_utils

    mteb_utils.mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(
    hf_runner,
    vllm_runner,
    model_info: EmbedModelInfo,
    example_prompts,
) -> None:
    """Compare vLLM embeddings of the example prompts with the HF reference.

    Entries with ``enable_test`` set to False are skipped.
    """
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    model_name = model_info.name
    model_dtype = model_info.dtype

    # Embed with vLLM first; its context is closed before the reference
    # model is opened.
    with vllm_runner(model_name,
                     task="embed",
                     dtype=model_dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(model_name,
                   dtype=model_dtype,
                   is_sentence_transformer=True) as hf_model:
        run_embedding_correctness_test(hf_model, example_prompts,
                                       vllm_outputs)
tests/models/language/pooling/test_snowflake_arctic_embed.py
View file @
e4b87133
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
...utils
import
EmbedModelInfo
,
check_embeddings_close
import
pytest
EMBEDDING_PROMPTS
=
[
from
...utils
import
EmbedModelInfo
,
run_embedding_correctness_test
'what is snowflake?'
,
'Where can I get the best tacos?'
,
'The Data Cloud!'
,
'Mexico City of Course!'
]
MODELS
=
[
MODELS
=
[
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-xs"
,
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-xs"
,
...
@@ -45,51 +41,34 @@ MODELS = [
...
@@ -45,51 +41,34 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models_mteb
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
example_prompts
,
model_info
:
EmbedModelInfo
,
model_info
:
EmbedModelInfo
,
dtype
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
if
not
model_info
.
enable_test
:
from
.mteb_utils
import
mteb_test_embed_models
# A model family has many models with the same architecture,
mteb_test_embed_models
(
hf_runner
,
vllm_runner
,
model_info
)
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
example_prompts
=
example_prompts
+
EMBEDDING_PROMPTS
vllm_extra_kwargs
=
{
"hf_overrides"
:
{
"is_matryoshka"
:
model_info
.
is_matryoshka
}
}
with
hf_runner
(
model_info
.
name
,
dtype
=
dtype
,
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
is_sentence_transformer
=
True
)
as
hf_model
:
def
test_models_correctness
(
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
,
example_prompts
,
)
->
None
:
if
not
model_info
.
enable_test
:
pytest
.
skip
(
"Skipping test."
)
with
vllm_runner
(
model_info
.
name
,
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
task
=
"embed"
,
dtype
=
dtype
,
dtype
=
model_info
.
dtype
,
max_model_len
=
None
,
max_model_len
=
None
)
as
vllm_model
:
**
vllm_extra_kwargs
)
as
vllm_model
:
assert
(
vllm_model
.
model
.
llm_engine
.
model_config
.
is_matryoshka
==
model_info
.
is_matryoshka
)
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
vllm_model
.
model
.
llm_engine
.
model_config
.
architectures
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
check_embeddings_close
(
with
hf_runner
(
embeddings_0_lst
=
hf_outputs
,
model_info
.
name
,
embeddings_1_lst
=
vllm_outputs
,
dtype
=
model_info
.
dtype
,
name_0
=
"hf"
,
is_sentence_transformer
=
True
,
name_1
=
"vllm"
,
)
as
hf_model
:
tol
=
1e-2
,
run_embedding_correctness_test
(
hf_model
,
example_prompts
,
vllm_outputs
)
)
tests/models/utils.py
View file @
e4b87133
...
@@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
...
@@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
class
EmbedModelInfo
(
NamedTuple
):
class
EmbedModelInfo
(
NamedTuple
):
name
:
str
name
:
str
is_matryoshka
:
bool
is_matryoshka
:
bool
=
False
matryoshka_dimensions
:
Optional
[
list
[
int
]]
=
None
matryoshka_dimensions
:
Optional
[
list
[
int
]]
=
None
architecture
:
str
=
""
architecture
:
str
=
""
dtype
:
str
=
"auto"
enable_test
:
bool
=
True
enable_test
:
bool
=
True
...
...
vllm/model_executor/models/bert.py
View file @
e4b87133
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/bert_with_rope.py
0 → 100644
View file @
e4b87133
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
e4b87133
...
@@ -126,7 +126,7 @@ _EMBEDDING_MODELS = {
...
@@ -126,7 +126,7 @@ _EMBEDDING_MODELS = {
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"GteModel"
:
(
"bert
"
,
"GteEmbedding
Model"
),
"GteModel"
:
(
"bert
_with_rope"
,
"Gte
Model"
),
"InternLM2ForRewardModel"
:
(
"internlm2"
,
"InternLM2ForRewardModel"
),
"InternLM2ForRewardModel"
:
(
"internlm2"
,
"InternLM2ForRewardModel"
),
"JambaForSequenceClassification"
:
(
"jamba"
,
"JambaForSequenceClassification"
),
# noqa: E501
"JambaForSequenceClassification"
:
(
"jamba"
,
"JambaForSequenceClassification"
),
# noqa: E501
"LlamaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlamaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
...
@@ -136,7 +136,7 @@ _EMBEDDING_MODELS = {
...
@@ -136,7 +136,7 @@ _EMBEDDING_MODELS = {
if
arch
==
"LlamaForCausalLM"
if
arch
==
"LlamaForCausalLM"
},
},
"MistralModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MistralModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"NomicBertModel"
:
(
"bert"
,
"NomicBert
Embedding
Model"
),
"NomicBertModel"
:
(
"bert
_with_rope
"
,
"NomicBertModel"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
...
...
vllm/model_executor/models/roberta.py
View file @
e4b87133
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
itertools
import
itertools
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
...
@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
get_cross_encoder_activation_function
)
get_cross_encoder_activation_function
)
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
...
@@ -125,39 +126,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -125,39 +126,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def
_build_model
(
self
,
def
_build_model
(
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
BertModel
:
prefix
:
str
=
""
)
->
Union
[
BertModel
,
BertWithRope
]
:
if
(
vllm_config
.
model_config
.
hf_config
.
position_embedding_type
==
if
(
vllm_config
.
model_config
.
hf_config
.
position_embedding_type
==
"rotary"
):
"rotary"
):
config
=
vllm_config
.
model_config
.
hf_config
return
JinaRobertaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"base"
:
config
.
rotary_emb_base
,
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
return
BertModel
(
vllm_config
=
vllm_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
prefix
)
else
:
else
:
return
BertModel
(
vllm_config
=
vllm_config
,
return
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
embedding_class
=
RobertaEmbedding
)
embedding_class
=
RobertaEmbedding
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
if
getattr
(
self
.
config
,
"lora_rank"
,
0
)
>
0
:
scaling
=
self
.
config
.
lora_alpha
/
self
.
config
.
lora_rank
weights
=
jina_merge_lora_weights
(
weights
,
scaling
)
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
# Separate weights in "roberta"-prefixed and all else (not in memory).
# Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base.
# For use with models like FacebookAI/roberta-base.
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
=
jina_to_vllm_mapper
.
apply
(
bert_weights
)
loaded
=
self
.
model
.
load_weights
(
bert_weights
)
loaded
=
self
.
model
.
load_weights
(
bert_weights
)
if
not
len
(
loaded
):
if
not
len
(
loaded
):
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
...
@@ -178,6 +160,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -178,6 +160,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
jina_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
'emb_ln'
:
"embeddings.LayerNorm"
,
'layers'
:
"layer"
,
'mixer.Wqkv'
:
"attention.self.qkv_proj"
,
'mixer.out_proj'
:
"attention.output.dense"
,
'norm1'
:
"attention.output.LayerNorm"
,
'mlp.fc1'
:
"intermediate.dense"
,
'mlp.fc2'
:
"output.dense"
,
'norm2'
:
"output.LayerNorm"
,
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -195,7 +189,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -195,7 +189,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
=
jina_to_vllm_mapper
.
apply
(
bert_weights
)
bert_weights
=
self
.
jina_to_vllm_mapper
.
apply
(
bert_weights
)
self
.
roberta
.
load_weights
(
bert_weights
)
self
.
roberta
.
load_weights
(
bert_weights
)
...
@@ -276,57 +270,3 @@ def roberta_task_weights_filter(
...
@@ -276,57 +270,3 @@ def roberta_task_weights_filter(
return
encoder_decoder_weights
(),
((
n
,
w
)
for
n
,
w
in
all_weights2
return
encoder_decoder_weights
(),
((
n
,
w
)
for
n
,
w
in
all_weights2
if
not
n
.
startswith
(
"roberta."
))
if
not
n
.
startswith
(
"roberta."
))
# Substring renames applied to Jina checkpoint weight names
# (left: substring in the checkpoint name, right: its replacement).
jina_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={
        # embeddings / layer container
        "emb_ln": "embeddings.LayerNorm",
        "layers": "layer",
        # attention
        "mixer.Wqkv": "attention.self.qkv_proj",
        "mixer.out_proj": "attention.output.dense",
        "norm1": "attention.output.LayerNorm",
        # feed-forward
        "mlp.fc1": "intermediate.dense",
        "mlp.fc2": "output.dense",
        "norm2": "output.LayerNorm",
    })
@torch.inference_mode()
def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                            scaling: float = 1.0):
    """Fold LoRA adapter weights into their base weights.

    Used for jina-embeddings-v3. This is a temporary solution until there is
    a better way to handle these LoRA weights.

    For every entry whose name contains ``".original"``, the matching
    ``".0.lora_A"`` / ``".0.lora_B"`` tensors are looked up, their last
    slice along dim 0 is selected, and ``original + (B @ A) * scaling`` is
    stored under the same name with ``".parametrizations"`` removed. The
    three source entries are dropped; all other entries pass through
    unchanged.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.
        scaling: Factor applied to the low-rank product before it is added.

    Returns:
        A list of ``(name, tensor)`` pairs. Merged tensors are on the CPU
        and have the dtype of their ``".original"`` tensor.
    """
    weights = dict(weights)

    # Merge on the GPU when one is available, otherwise on the CPU, so
    # loading does not crash on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    o = ".original"
    a = ".0.lora_A"
    b = ".0.lora_B"

    # Last slice along dim 0 of each LoRA tensor: text-matching.
    i = -1

    # Iterate over a snapshot of the keys: entries are added and deleted
    # inside the loop.
    for name in list(weights.keys()):
        if o not in name:
            continue

        dtype = weights[name].dtype
        shape = weights[name].shape
        weight_name = name[:-len(o)]

        if "embeddings" in weight_name:
            # For embedding weights the A and B tensors swap roles in the
            # product.
            B = weights[weight_name + a][i].to(device).float()
            A = weights[weight_name + b][i].to(device).float()
        else:
            B = weights[weight_name + b][i].to(device).float()
            A = weights[weight_name + a][i].to(device).float()

        weight = (weights[weight_name + o].to(device) +
                  torch.matmul(B, A).view(shape) * scaling)
        # Return to the CPU and to the checkpoint's original dtype.
        weight = weight.cpu().to(dtype)

        weights[weight_name.replace(".parametrizations", "")] = weight

        del weights[weight_name + o]
        del weights[weight_name + a]
        del weights[weight_name + b]

    return list(weights.items())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment