Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15986f59
Unverified
Commit
15986f59
authored
Oct 04, 2024
by
Xin Yang
Committed by
GitHub
Oct 05, 2024
Browse files
[Model] Support Gemma2 embedding model (#9004)
parent
53b3a330
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
99 additions
and
3 deletions
+99
-3
tests/conftest.py
tests/conftest.py
+1
-0
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+10
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+5
-2
vllm/model_executor/models/gemma2_embedding.py
vllm/model_executor/models/gemma2_embedding.py
+82
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
tests/conftest.py
View file @
15986f59
...
...
@@ -277,6 +277,7 @@ class HfRunner:
SentenceTransformer
(
model_name
,
device
=
"cpu"
,
trust_remote_code
=
True
,
).
to
(
dtype
=
torch_dtype
))
else
:
model_kwargs
=
model_kwargs
if
model_kwargs
is
not
None
else
{}
...
...
tests/models/embedding/language/test_embedding.py
View file @
15986f59
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/
test_llama
_embedding.py`.
Run `pytest tests/models/
embedding/language/test
_embedding.py`.
"""
import
pytest
import
torch
...
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
MODELS
=
[
"intfloat/e5-mistral-7b-instruct"
,
"BAAI/bge-multilingual-gemma2"
,
]
...
...
@@ -28,6 +29,14 @@ def test_models(
model
:
str
,
dtype
:
str
,
)
->
None
:
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids different between hf_model and vllm_model.
# So we need to strip the input texts to avoid test failing.
example_prompts
=
[
str
(
s
).
strip
()
for
s
in
example_prompts
]
with
hf_runner
(
model
,
dtype
=
dtype
,
is_embedding_model
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
...
...
vllm/model_executor/models/gemma2.py
View file @
15986f59
...
...
@@ -278,11 +278,14 @@ class Gemma2Model(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
...
...
vllm/model_executor/models/gemma2_embedding.py
0 → 100644
View file @
15986f59
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
class Gemma2EmbeddingModel(nn.Module):
    """Embedding wrapper around ``Gemma2Model``.

    The wrapped ``Gemma2Model`` produces the hidden states; a ``Pooler``
    configured with ``PoolingType.LAST`` and ``normalize=True`` turns those
    hidden states into the pooled output.

    Attributes:
        model: The ``Gemma2Model`` instance that runs the forward pass.
        _pooler: The ``Pooler`` instance applied by :meth:`pooler`.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Build the backbone and the pooler.

        Args:
            **kwargs: Forwarded unchanged to the ``Gemma2Model`` constructor.
        """
        super().__init__()
        self.model = Gemma2Model(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate the forward pass to the wrapped ``Gemma2Model``.

        All arguments are handed to ``self.model.forward`` positionally, in
        the order they are declared here, and its result is returned as is.
        """
        backbone_output = self.model.forward(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return backbone_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the parameters of ``self.model``.

        Checkpoint names containing one of the shard names below are
        rewritten to the corresponding fused parameter name and loaded
        through that parameter's ``weight_loader`` with the shard id. Every
        other tensor is loaded through the parameter's ``weight_loader`` if
        it has one, otherwise through ``default_weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.

        Raises:
            KeyError: If a (possibly rewritten) name is not a parameter of
                ``self.model`` and is not a skippable ``.bias`` entry.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        def is_extra_bias(candidate: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return candidate.endswith(".bias") and candidate not in params_dict

        for name, loaded_weight in weights:
            loaded_as_shard = False
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    # The rename is kept even when the bias check below skips
                    # this entry, so later iterations and the fallback path
                    # see the rewritten name.
                    name = name.replace(shard_name, fused_name)
                    if not is_extra_bias(name):
                        fused_param = params_dict[name]
                        shard_loader = fused_param.weight_loader
                        shard_loader(fused_param, loaded_weight, shard_id)
                        loaded_as_shard = True
                        break
            if loaded_as_shard:
                continue

            # Fallback path: no stacked mapping consumed this tensor.
            if is_extra_bias(name):
                continue
            plain_param = params_dict[name]
            plain_loader = getattr(plain_param, "weight_loader",
                                   default_weight_loader)
            plain_loader(plain_param, loaded_weight)
vllm/model_executor/models/registry.py
View file @
15986f59
...
...
@@ -83,6 +83,7 @@ _GENERATION_MODELS = {
_EMBEDDING_MODELS
=
{
"MistralModel"
:
(
"llama_embedding"
,
"LlamaEmbeddingModel"
),
"Qwen2ForRewardModel"
:
(
"qwen2_rm"
,
"Qwen2ForRewardModel"
),
"Gemma2Model"
:
(
"gemma2_embedding"
,
"Gemma2EmbeddingModel"
),
}
_MULTIMODAL_MODELS
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment