Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1f5d13ab
"vllm/vscode:/vscode.git/clone" did not exist on "60f76243344d2d3deca5e5ecdade547acc7fed50"
Unverified
Commit
1f5d13ab
authored
Apr 08, 2025
by
wang.yuqi
Committed by
GitHub
Apr 08, 2025
Browse files
[New Model]: jinaai/jina-embeddings-v3 (#16120)
parent
90cb44eb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
297 additions
and
86 deletions
+297
-86
examples/offline_inference/embed_jina_embeddings_v3.py
examples/offline_inference/embed_jina_embeddings_v3.py
+50
-0
tests/conftest.py
tests/conftest.py
+3
-2
tests/models/embedding/language/test_jina.py
tests/models/embedding/language/test_jina.py
+61
-3
vllm/config.py
vllm/config.py
+5
-0
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+49
-17
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+129
-64
No files found.
examples/offline_inference/embed_jina_embeddings_v3.py
0 → 100644
View file @
1f5d13ab
# SPDX-License-Identifier: Apache-2.0
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
def main(args: Namespace):
    """Embed a handful of multilingual sentences and print a preview.

    Args:
        args: Parsed command-line namespace. Every attribute is forwarded
            to ``LLM`` as a keyword argument.
    """
    # Sample prompts: the same sentence in several languages.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass task="embed" for embedding models
    engine_kwargs = vars(args)
    model = LLM(**engine_kwargs)

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    # Only text matching task is supported for now. See #16120
    outputs = model.embed(prompts)

    # Print the outputs.
    separator = "-" * 60
    print("\nGenerated Outputs:")
    print("Only text matching task is supported for now. See #16120")
    print(separator)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        # Long vectors are shown as their first 16 values followed by an
        # ellipsis; short ones are printed unchanged.
        if len(embeds) > 16:
            embeds_trimmed = str(embeds[:16])[:-1] + ", ...]"
        else:
            embeds_trimmed = embeds
        print(f"Prompt: {prompt!r} \n"
              f"Embeddings for text matching: {embeds_trimmed} "
              f"(size={len(embeds)})")
        print(separator)
if __name__ == "__main__":
    # Start from the engine's full set of command-line options.
    cli = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Set example specific arguments
    cli.set_defaults(
        model="jinaai/jina-embeddings-v3",
        task="embed",
        trust_remote_code=True,
    )
    main(cli.parse_args())
tests/conftest.py
View file @
1f5d13ab
...
@@ -671,8 +671,9 @@ class HfRunner:
...
@@ -671,8 +671,9 @@ class HfRunner:
return
[(
output_ids
,
output_str
,
output_logprobs
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
def
encode
(
self
,
prompts
:
list
[
str
])
->
list
[
list
[
torch
.
Tensor
]]:
def
encode
(
self
,
prompts
:
list
[
str
],
*
args
,
return
self
.
model
.
encode
(
prompts
)
**
kwargs
)
->
list
[
list
[
torch
.
Tensor
]]:
return
self
.
model
.
encode
(
prompts
,
*
args
,
**
kwargs
)
def
predict
(
self
,
prompts
:
list
[
list
[
str
]])
->
torch
.
Tensor
:
def
predict
(
self
,
prompts
:
list
[
list
[
str
]])
->
torch
.
Tensor
:
return
self
.
model
.
predict
(
prompts
,
convert_to_tensor
=
True
)
return
self
.
model
.
predict
(
prompts
,
convert_to_tensor
=
True
)
...
...
tests/models/embedding/language/test_jina
_reranker_v2
.py
→
tests/models/embedding/language/test_jina.py
View file @
1f5d13ab
...
@@ -2,13 +2,15 @@
...
@@ -2,13 +2,15 @@
# ruff: noqa: E501
# ruff: noqa: E501
"""Compare the scoring outputs of HF and vLLM models.
"""Compare the scoring outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_jina
_reranker_v2
.py`.
Run `pytest tests/models/embedding/language/test_jina.py`.
"""
"""
import
math
import
math
import
pytest
import
pytest
MODELS
=
[
from
tests.models.embedding.utils
import
check_embeddings_close
SCORING_MODELS
=
[
"jinaai/jina-reranker-v2-base-multilingual"
,
# Roberta
"jinaai/jina-reranker-v2-base-multilingual"
,
# Roberta
]
]
...
@@ -27,8 +29,21 @@ TEXTS_2 = [
...
@@ -27,8 +29,21 @@ TEXTS_2 = [
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています"
,
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています"
,
]
]
EMBEDDING_MODELS
=
[
"jinaai/jina-embeddings-v3"
,
]
EMBEDDING_PROMPTS
=
[
"Follow the white rabbit."
,
# English
"Sigue al conejo blanco."
,
# Spanish
"Suis le lapin blanc."
,
# French
"跟着白兔走。"
,
# Chinese
"اتبع الأرنب الأبيض."
,
# Arabic
"Folge dem weißen Kaninchen."
,
# German
]
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
MODELS
)
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
SCORING_
MODELS
)
def
model_name
(
request
):
def
model_name
(
request
):
yield
request
.
param
yield
request
.
param
...
@@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
...
@@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert
math
.
isclose
(
hf_outputs
[
0
],
vllm_outputs
[
0
],
rel_tol
=
0.01
)
assert
math
.
isclose
(
hf_outputs
[
0
],
vllm_outputs
[
0
],
rel_tol
=
0.01
)
assert
math
.
isclose
(
hf_outputs
[
1
],
vllm_outputs
[
1
],
rel_tol
=
0.01
)
assert
math
.
isclose
(
hf_outputs
[
1
],
vllm_outputs
[
1
],
rel_tol
=
0.01
)
@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
    """Module-scoped fixture supplying each entry of ``EMBEDDING_MODELS``."""
    # No teardown is needed, so the parameter is handed back directly.
    return request.param
def test_is_matryoshka(vllm_runner, emb_model_name):
    """The loaded embedding model's config must report ``is_matryoshka``."""
    runner_kwargs = {"task": "embed", "max_model_len": None}
    with vllm_runner(emb_model_name, **runner_kwargs) as vllm_model:
        model_config = vllm_model.model.llm_engine.model_config
        assert model_config.is_matryoshka
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_embeddings(
    hf_runner,
    vllm_runner,
    model,
    dtype: str,
    monkeypatch,
) -> None:
    """Compare vLLM embeddings with the HF runner's output for each prompt."""
    prompts = EMBEDDING_PROMPTS

    # Reference side: the HF runner in sentence-transformer mode, encoding
    # with the "text-matching" task.
    hf_kwargs = {"dtype": dtype, "is_sentence_transformer": True}
    with hf_runner(model, **hf_kwargs) as reference:
        expected = reference.encode(prompts, task="text-matching")

    # vLLM side: same prompts, embed task.
    vllm_kwargs = {"task": "embed", "dtype": dtype, "max_model_len": None}
    with vllm_runner(model, **vllm_kwargs) as candidate:
        actual = candidate.encode(prompts)

    check_embeddings_close(
        embeddings_0_lst=expected,
        embeddings_1_lst=actual,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
vllm/config.py
View file @
1f5d13ab
...
@@ -1130,6 +1130,11 @@ class ModelConfig:
...
@@ -1130,6 +1130,11 @@ class ModelConfig:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
return
ModelRegistry
.
is_v1_compatible
(
architectures
)
return
ModelRegistry
.
is_v1_compatible
(
architectures
)
@property
def is_matryoshka(self) -> bool:
    """Report whether the HF config marks the model as matryoshka.

    Returns ``True`` when the config has a ``matryoshka_dimensions``
    attribute; otherwise returns the config's ``is_matryoshka`` attribute,
    or ``False`` when that attribute is absent.
    """
    hf_config = self.hf_config
    if hasattr(hf_config, "matryoshka_dimensions"):
        return True
    return getattr(hf_config, "is_matryoshka", False)
class
CacheConfig
:
class
CacheConfig
:
"""Configuration for the KV cache.
"""Configuration for the KV cache.
...
...
vllm/model_executor/models/bert.py
View file @
1f5d13ab
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
PoolingType
)
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
...
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self
.
size
=
config
.
hidden_size
self
.
size
=
config
.
hidden_size
self
.
word_embeddings
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
word_embeddings
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
position_embeddings
=
VocabParallelEmbedding
(
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
token_type_embeddings
=
VocabParallelEmbedding
(
self
.
token_type_embeddings
=
VocabParallelEmbedding
(
config
.
type_vocab_size
,
config
.
hidden_size
)
config
.
type_vocab_size
,
config
.
hidden_size
)
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
self
.
position_embedding_type
=
config
.
position_embedding_type
self
.
position_embedding_type
=
config
.
position_embedding_type
if
self
.
position_embedding_type
!=
"absolute"
:
if
self
.
position_embedding_type
==
"absolute"
:
raise
ValueError
(
"Only 'absolute' position_embedding_type"
+
self
.
position_embeddings
=
VocabParallelEmbedding
(
" is supported"
)
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
elif
self
.
position_embedding_type
==
"rotary"
:
self
.
position_embeddings
=
None
self
.
position_ids
=
None
else
:
raise
ValueError
(
"Only 'absolute' and 'rotary' "
+
"position_embedding_type is supported"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
...
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings.
# Input embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
...
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
...
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
embeddings
=
inputs_embeds
+
token_type_embeddings
if
self
.
position_embedding_type
==
"absolute"
:
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
+=
position_embeddings
embeddings
=
self
.
LayerNorm
(
embeddings
)
embeddings
=
self
.
LayerNorm
(
embeddings
)
return
embeddings
return
embeddings
...
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
...
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
...
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
...
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer
(
config
=
config
,
BertLayer
(
config
=
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
])
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
for
layer
in
self
.
layer
:
for
layer
in
self
.
layer
:
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
...
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config
:
BertConfig
,
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
...
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps
=
config
.
layer_norm_eps
,
layer_norm_eps
=
config
.
layer_norm_eps
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.attention"
)
prefix
=
f
"
{
prefix
}
.attention"
)
self
.
intermediate
=
BertIntermediate
(
self
.
intermediate
=
BertIntermediate
(
...
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
...
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
):
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
):
attn_output
=
self
.
attention
(
hidden_states
)
attn_output
=
self
.
attention
(
positions
,
hidden_states
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
return
output
return
output
...
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
...
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps
:
float
,
layer_norm_eps
:
float
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
...
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads
=
num_attention_heads
,
num_attention_heads
=
num_attention_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.output"
)
prefix
=
f
"
{
prefix
}
.output"
)
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
...
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
...
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
self_output
=
self
.
self
(
hidden_states
)
self_output
=
self
.
self
(
positions
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
...
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
...
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads
:
int
,
num_attention_heads
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
...
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
if
rotary_kwargs
:
self
.
rotary_emb
=
get_rope
(
**
rotary_kwargs
)
else
:
self
.
rotary_emb
=
None
self
.
attn
=
Attention
(
num_heads
=
self
.
num_heads
,
self
.
attn
=
Attention
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_dim
,
head_size
=
self
.
head_dim
,
scale
=
self
.
scaling
,
scale
=
self
.
scaling
,
...
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
...
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
rotary_emb
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
output
=
self
.
attn
(
q
,
k
,
v
)
output
=
self
.
attn
(
q
,
k
,
v
)
return
output
return
output
...
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
...
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
embedding_class
:
type
=
BertEmbedding
,
embedding_class
:
type
=
BertEmbedding
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
add_pooling_layer
:
bool
=
False
):
add_pooling_layer
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
embeddings
=
embedding_class
(
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
)
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
...
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
...
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
)
return
self
.
encoder
(
position_ids
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
...
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
self
.
_build_pooler
(
pooler_config
)
self
.
_pooler
=
self
.
_build_pooler
(
pooler_config
)
...
...
vllm/model_executor/models/roberta.py
View file @
1f5d13ab
...
@@ -22,30 +22,6 @@ from vllm.transformers_utils.config import (
...
@@ -22,30 +22,6 @@ from vllm.transformers_utils.config import (
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
def
roberta_task_weights_filter
(
all_weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]
)
->
Tuple
[
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]]:
"""
Separate task-specific weights that are applied on top
of the encoder-decoder bert base.
To do so, return two generators over the original iterator.
Also, remove the "roberta." prefix to make it loadable
from vanilla BertModel.
"""
# Copy of a lazy iterator without in-memory overhead so both
# iterators can be iterated upon independently.
all_weights1
,
all_weights2
=
itertools
.
tee
(
all_weights
)
def
encoder_decoder_weights
():
for
name
,
weight
in
all_weights1
:
if
name
.
startswith
(
"roberta."
):
yield
(
name
[
len
(
"roberta."
):],
weight
)
return
encoder_decoder_weights
(),
((
n
,
w
)
for
n
,
w
in
all_weights2
if
not
n
.
startswith
(
"roberta."
))
class
RobertaEmbedding
(
nn
.
Module
):
class
RobertaEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
RobertaConfig
):
def
__init__
(
self
,
config
:
RobertaConfig
):
...
@@ -119,30 +95,6 @@ class RobertaEmbedding(nn.Module):
...
@@ -119,30 +95,6 @@ class RobertaEmbedding(nn.Module):
return
embeddings
return
embeddings
# Adapted from transformers
def
create_position_ids_from_input_ids
(
input_ids
,
padding_idx
,
past_key_values_length
=
0
):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask
=
input_ids
.
ne
(
padding_idx
).
int
()
incremental_indices
=
(
torch
.
cumsum
(
mask
,
dim
=
0
).
type_as
(
mask
)
+
past_key_values_length
)
*
mask
return
incremental_indices
.
long
()
+
padding_idx
# Adapted from transformers
# Adapted from transformers
class
RobertaClassificationHead
(
nn
.
Module
):
class
RobertaClassificationHead
(
nn
.
Module
):
"""Head for sentence-level classification tasks."""
"""Head for sentence-level classification tasks."""
...
@@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def
_build_model
(
self
,
def
_build_model
(
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
BertModel
:
prefix
:
str
=
""
)
->
BertModel
:
return
BertModel
(
vllm_config
=
vllm_config
,
if
(
vllm_config
.
model_config
.
hf_config
.
position_embedding_type
==
prefix
=
prefix
,
"rotary"
):
embedding_class
=
RobertaEmbedding
)
config
=
vllm_config
.
model_config
.
hf_config
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"base"
:
config
.
rotary_emb_base
,
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
return
BertModel
(
vllm_config
=
vllm_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
prefix
)
else
:
return
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
embedding_class
=
RobertaEmbedding
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
if
getattr
(
self
.
config
,
"lora_rank"
,
0
)
>
0
:
scaling
=
self
.
config
.
lora_alpha
/
self
.
config
.
lora_rank
weights
=
jina_merge_lora_weights
(
weights
,
scaling
)
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
# Separate weights in "roberta"-prefixed and all else (not in memory).
# Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base.
# For use with models like FacebookAI/roberta-base.
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
=
jina_to_vllm_mapper
.
apply
(
bert_weights
)
loaded
=
self
.
model
.
load_weights
(
bert_weights
)
loaded
=
self
.
model
.
load_weights
(
bert_weights
)
if
not
len
(
loaded
):
if
not
len
(
loaded
):
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
...
@@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
jina_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
'emb_ln'
:
"embeddings.LayerNorm"
,
'layers'
:
"layer"
,
'mixer.Wqkv'
:
"attention.self.qkv_proj"
,
'mixer.out_proj'
:
"attention.output.dense"
,
'norm1'
:
"attention.output.LayerNorm"
,
'mlp.fc1'
:
"intermediate.dense"
,
'mlp.fc2'
:
"output.dense"
,
'norm2'
:
"output.LayerNorm"
,
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -232,7 +195,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -232,7 +195,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
,
task_weights
=
roberta_task_weights_filter
(
weights
)
bert_weights
=
self
.
jina_to_vllm_mapper
.
apply
(
bert_weights
)
bert_weights
=
jina_to_vllm_mapper
.
apply
(
bert_weights
)
self
.
roberta
.
load_weights
(
bert_weights
)
self
.
roberta
.
load_weights
(
bert_weights
)
...
@@ -265,3 +228,105 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -265,3 +228,105 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
                                       padding_idx,
                                       past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor of token ids; the running count is taken
            along dimension 0.
        padding_idx: id of the padding token.
        past_key_values_length: offset added to every non-padding position.

    Returns: torch.Tensor (long) of position ids; padding tokens map to
        padding_idx.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=0).type_as(not_padding)
    # Multiplying by the mask zeroes the slots that hold padding tokens.
    offsets = (running_count + past_key_values_length) * not_padding
    return offsets.long() + padding_idx
def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """
    Split a weight stream into base-model weights and task-specific weights.

    Weights whose name starts with "roberta." go to the first iterable,
    with that prefix stripped so they are loadable from vanilla BertModel;
    every other weight goes to the second iterable with its name unchanged.

    Both results are lazy generators over the same underlying iterator.
    """
    prefix = "roberta."
    # tee() lets the two views be consumed independently; items read by
    # one view are buffered until the other view has read them too.
    base_stream, task_stream = itertools.tee(all_weights)

    base_weights = ((name[len(prefix):], weight)
                    for name, weight in base_stream
                    if name.startswith(prefix))
    task_weights = ((name, weight) for name, weight in task_stream
                    if not name.startswith(prefix))
    return base_weights, task_weights
# Substring renames for jina checkpoint weight names. Each key is a fragment
# of the checkpoint's parameter names and each value is the fragment used by
# the BERT-style modules in this package; callers run weights through
# ``jina_to_vllm_mapper.apply(...)`` before loading them.
# NOTE(review): presumably the replacements are applied in dict order, so
# keep the entry order stable — confirm against WeightsMapper.
jina_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={
        'emb_ln': "embeddings.LayerNorm",
        'layers': "layer",
        'mixer.Wqkv': "attention.self.qkv_proj",
        'mixer.out_proj': "attention.output.dense",
        'norm1': "attention.output.LayerNorm",
        'mlp.fc1': "intermediate.dense",
        'mlp.fc2': "output.dense",
        'norm2': "output.LayerNorm",
    })
@torch.inference_mode()
def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                            scaling: float = 1.0):
    """Fold the jina-embeddings-v3 LoRA adapter into its base weights.

    For every parameter stored as a ``<name>.original`` /
    ``<name>.0.lora_A`` / ``<name>.0.lora_B`` triple, computes
    ``original + (B @ A) * scaling`` and stores the result under ``<name>``
    with ``".parametrizations"`` removed; the three source entries are
    dropped. All other weights pass through unchanged.

    This is a temporary solution until LoRA weights for this model are
    handled natively.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.
            It is fully materialised, because each merge needs three
            entries at once.
        scaling: Factor applied to the low-rank product before it is added
            to the base weight.

    Returns:
        A list of ``(name, tensor)`` pairs with the adapters merged in.
        Merged tensors are returned on the CPU in the dtype of their
        ``.original`` tensor.

    Raises:
        KeyError: If a ``.original`` entry has no matching LoRA factors.
    """
    weights = dict(weights)

    o = ".original"
    a = ".0.lora_A"
    b = ".0.lora_B"

    # Index of the adapter to merge; the last one is the text-matching task.
    i = -1

    # Use the GPU when one is present, but fall back to the CPU instead of
    # crashing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for name in list(weights):
        # The suffix is stripped by length below, so only names that really
        # end with it can be handled.
        if not name.endswith(o):
            continue

        dtype = weights[name].dtype
        shape = weights[name].shape
        weight_name = name[:-len(o)]

        if "embeddings" in weight_name:
            # For embedding weights the stored factors are multiplied in
            # the opposite order (lora_A on the left).
            B = weights[weight_name + a][i].to(device).float()
            A = weights[weight_name + b][i].to(device).float()
        else:
            B = weights[weight_name + b][i].to(device).float()
            A = weights[weight_name + a][i].to(device).float()

        # The product is computed in float32 and cast back afterwards.
        weight = (weights[weight_name + o].to(device) +
                  torch.matmul(B, A).view(shape) * scaling)
        weight = weight.cpu().to(dtype)

        weights[weight_name.replace(".parametrizations", "")] = weight

        del weights[weight_name + o]
        del weights[weight_name + a]
        del weights[weight_name + b]

    return list(weights.items())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment