Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13370712
Unverified
Commit
13370712
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7e4bbda5
Changes
32
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
262 additions
and
227 deletions
+262
-227
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+14
-1
tests/conftest.py
tests/conftest.py
+0
-1
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+5
-0
tests/models/test_registry.py
tests/models/test_registry.py
+14
-17
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+40
-5
tests/test_config.py
tests/test_config.py
+1
-2
vllm/config.py
vllm/config.py
+25
-0
vllm/inputs/registry.py
vllm/inputs/registry.py
+8
-8
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+1
-3
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+14
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+14
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+98
-0
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-2
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-56
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-95
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+3
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+5
-21
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+3
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
13370712
...
...
@@ -334,7 +334,6 @@ steps:
commands
:
-
pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
-
pytest -v -s models/embedding/language -m core_model
-
pytest -v -s models/embedding/vision_language -m core_model
-
label
:
Language Models Test (Extended)
# 50min
optional
:
true
...
...
@@ -346,7 +345,6 @@ steps:
commands
:
-
pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
-
pytest -v -s models/embedding/language -m 'not core_model'
-
pytest -v -s models/embedding/vision_language -m 'not core_model'
-
label
:
Multi-Modal Models Test (Standard)
# 26min
#mirror_hardwares: [amd]
...
...
@@ -359,6 +357,7 @@ steps:
commands
:
-
pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
-
pytest -v -s models/embedding/vision_language -m core_model
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
...
...
@@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
-
pytest -v -s models/embedding/vision_language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
...
...
docs/source/models/supported_models.rst
View file @
13370712
...
...
@@ -357,7 +357,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`
(see note)
, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
...
...
@@ -378,6 +378,10 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
...
...
@@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎
- ✅︎
.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
...
...
tests/conftest.py
View file @
13370712
...
...
@@ -263,7 +263,6 @@ class HfRunner:
dtype
:
str
=
"half"
,
*
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_embedding_model
:
bool
=
False
,
is_sentence_transformer
:
bool
=
False
,
is_cross_encoder
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
...
...
tests/models/embedding/language/test_embedding.py
View file @
13370712
...
...
@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import
pytest
from
vllm.config
import
PoolerConfig
from
..utils
import
check_embeddings_close
...
...
@@ -33,6 +35,9 @@ def test_models(
dtype
:
str
,
)
->
None
:
vllm_extra_kwargs
=
{}
if
model
==
"ssmits/Qwen2-7B-Instruct-embed-base"
:
vllm_extra_kwargs
[
"override_pooler_config"
]
=
\
PoolerConfig
(
pooling_type
=
"MEAN"
)
if
model
==
"Alibaba-NLP/gte-Qwen2-7B-instruct"
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
{
"is_causal"
:
False
}
...
...
tests/models/test_registry.py
View file @
13370712
...
...
@@ -6,11 +6,8 @@ import torch.cuda
from
vllm.model_executor.models
import
(
is_embedding_model
,
is_text_generation_model
,
supports_multimodal
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.models.registry
import
(
_CROSS_ENCODER_MODELS
,
_EMBEDDING_MODELS
,
_MULTIMODAL_MODELS
,
from
vllm.model_executor.models.adapters
import
as_embedding_model
from
vllm.model_executor.models.registry
import
(
_MULTIMODAL_MODELS
,
_SPECULATIVE_DECODING_MODELS
,
_TEXT_GENERATION_MODELS
,
ModelRegistry
)
...
...
@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
if
model_arch
in
_SPECULATIVE_DECODING_MODELS
:
pass
# Ignore these models which do not have a unified format
else
:
assert
is_text_generation_model
(
model_cls
)
is
(
model_arch
in
_
TEXT_GENERATION
_MODELS
or
model_arch
in
_MULTIMODAL_MODELS
)
embedding
_
model
s
=
{
**
_EMBEDDING_MODELS
,
**
_CROSS_ENCODER_MODELS
}
assert
i
s_embedding_model
(
model_cls
)
is
(
model_arch
in
embedding_model
s
)
assert
supports_multimodal
(
model_cls
)
is
(
model_arch
in
_MULTIMODAL_MODELS
)
return
# Ignore these models which do not have a unified format
if
(
model_arch
in
_TEXT_GENERATION_MODELS
or
model_arch
in
_
MULTIMODAL
_MODELS
):
assert
is_text_generation_model
(
model_cls
)
# All vLLM models should be convertible to an
embedding
model
embed_model
=
a
s_embedding_model
(
model_cls
)
assert
is_
embedding_model
(
embed_model
)
if
model_arch
in
_MULTIMODAL_MODELS
:
assert
supports_multimodal
(
model_cls
)
@
fork_new_process_for_each_test
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
13370712
from
typing
import
List
,
Optional
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.models.gemma2
import
Gemma2EmbeddingModel
from
vllm.sequence
import
IntermediateTensors
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
class
MyGemma2Embedding
(
Gemma2EmbeddingModel
):
class
MyGemma2Embedding
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
,
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
s
uper
().
forward
(
hidden_states
=
s
elf
.
model
(
input_ids
,
positions
,
kv_caches
,
...
...
@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
return
self
.
model
.
load_weights
(
weights
)
tests/test_config.py
View file @
13370712
...
...
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"bad_task"
),
[
(
"facebook/opt-125m"
,
"embedding"
),
(
"intfloat/e5-mistral-7b-instruct"
,
"generate"
),
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"generate"
),
])
def
test_incorrect_task
(
model_id
,
bad_task
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"does not support the .* task"
):
...
...
vllm/config.py
View file @
13370712
...
...
@@ -370,6 +370,31 @@ class ModelConfig:
selected_task
=
next
(
iter
(
supported_tasks_lst
))
if
len
(
supported_tasks
)
>
1
:
suffix_to_preferred_task
:
List
[
Tuple
[
str
,
_Task
]]
=
[
# Hardcode the models that are exceptions
(
"AquilaModel"
,
"generate"
),
(
"ChatGLMModel"
,
"generate"
),
# Other models follow this pattern
(
"ForCausalLM"
,
"generate"
),
(
"ForConditionalGeneration"
,
"generate"
),
(
"ChatModel"
,
"generate"
),
(
"LMHeadModel"
,
"generate"
),
(
"EmbeddingModel"
,
"embedding"
),
(
"RewardModel"
,
"embedding"
),
(
"ForSequenceClassification"
,
"embedding"
),
]
info
,
arch
=
ModelRegistry
.
inspect_model_cls
(
architectures
)
for
suffix
,
pref_task
in
suffix_to_preferred_task
:
if
arch
.
endswith
(
suffix
)
and
pref_task
in
supported_tasks
:
selected_task
=
pref_task
break
else
:
if
(
arch
.
endswith
(
"Model"
)
and
info
.
architecture
.
endswith
(
"ForCausalLM"
)
and
"embedding"
in
supported_tasks
):
selected_task
=
"embedding"
logger
.
info
(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'."
,
supported_tasks
,
selected_task
)
...
...
vllm/inputs/registry.py
View file @
13370712
...
...
@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
print_warning_once
,
resolve_mm_processor_kwargs
)
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
print_warning_once
,
resolve_mm_processor_kwargs
)
from
.data
import
ProcessorInputs
,
SingletonInputs
from
.parse
import
is_encoder_decoder_inputs
...
...
@@ -136,12 +136,12 @@ class InputRegistry:
"""
def
__init__
(
self
)
->
None
:
self
.
_dummy_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
DummyDataFactory
]
=
{}
self
.
_dummy_encoder_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
]
,
DummyDataFactory
]
=
{}
self
.
_input_processors_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
InputProcessor
]
=
{}
self
.
_dummy_factories_by_model_type
=
\
ClassRegistry
[
nn
.
Module
,
DummyDataFactory
]
()
self
.
_dummy_encoder_factories_by_model_type
=
\
ClassRegistry
[
nn
.
Module
,
DummyDataFactory
]
()
self
.
_input_processors_by_model_type
=
\
ClassRegistry
[
nn
.
Module
,
InputProcessor
]
()
def
_default_dummy_data_factory
(
self
,
...
...
vllm/model_executor/layers/pooler.py
View file @
13370712
...
...
@@ -60,9 +60,7 @@ class Pooler(nn.Module):
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
Optional
[
"Pooler"
]:
if
pooler_config
is
None
:
return
None
)
->
"Pooler"
:
return
cls
(
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
13370712
...
...
@@ -9,6 +9,7 @@ import itertools
import
json
import
math
import
os
import
warnings
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
...
...
@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger
=
init_logger
(
__name__
)
def
_initialize_model
(
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
def
_initialize_model
(
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
model_class
,
_
=
get_model_architecture
(
model_config
)
model_class
,
_
=
get_model_architecture
(
model_config
,
architectures
=
architectures
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
with
set_current_vllm_config
(
vllm_config
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
logger
.
warning
(
msg
)
warnings
.
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
logger
.
warning
(
"Trying to guess the arguments for old-style model class %s"
,
model_class
,
...
...
@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
# We only enable strict check for non-quanti
i
zed models
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
...
...
vllm/model_executor/model_loader/utils.py
View file @
13370712
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Tuple
,
Type
from
typing
import
Optional
,
Tuple
,
Type
import
torch
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.adapters
import
as_embedding_model
@
contextlib
.
contextmanager
...
...
@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
,
*
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
if
architectures
is
None
:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
...
...
@@ -32,7 +38,11 @@ def get_model_architecture(
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
if
model_config
.
task
==
"embedding"
:
model_cls
=
as_embedding_model
(
model_cls
)
return
model_cls
,
arch
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
...
...
vllm/model_executor/models/adapters.py
0 → 100644
View file @
13370712
from
collections.abc
import
Iterable
from
typing
import
Any
,
TypeVar
import
torch
import
torch.nn
as
nn
from
.interfaces_base
import
VllmModelForEmbedding
,
is_embedding_model
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
def
as_embedding_model
(
cls
:
_T
)
->
_T
:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if
is_embedding_model
(
cls
):
return
cls
# Lazy import
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
Pooler
,
PoolerOutput
,
PoolingType
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForEmbedding
(
cls
,
VllmModelForEmbedding
):
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# These are not used in embedding models
for
attr
in
(
"lm_head"
,
"logits_processor"
):
if
hasattr
(
self
,
attr
):
delattr
(
self
,
attr
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"_pooler"
,
None
):
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
,
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
# TODO: Support uninitialized params tracking
# We have deleted this attribute, so don't load it
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if
hasattr
(
self
,
"model"
)
and
hasattr
(
self
.
model
,
"load_weights"
):
# Whether only `self.model` contains parameters
model_is_only_param
=
all
(
name
==
"model"
or
next
(
child
.
parameters
(),
None
)
is
None
for
name
,
child
in
self
.
named_children
())
if
model_is_only_param
:
mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
mapper
.
apply
(
weights
)
self
.
model
.
load_weights
(
weights
)
return
# For most other models
if
hasattr
(
cls
,
"load_weights"
):
cls
.
load_weights
(
self
,
weights
)
# type: ignore
# Fallback
else
:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
ModelForEmbedding
.
__name__
=
cls
.
__name__
\
.
removesuffix
(
"ForCausalLM"
)
\
.
removesuffix
(
"ForConditionalGeneration"
)
\
.
removesuffix
(
"ChatModel"
)
\
.
removesuffix
(
"LMHeadModel"
)
+
"ForEmbedding"
return
ModelForEmbedding
# type: ignore
vllm/model_executor/models/blip2.py
View file @
13370712
...
...
@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/gemma2.py
View file @
13370712
...
...
@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
class
Gemma2EmbeddingModel
(
nn
.
Module
,
SupportsPP
):
"""
A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
vllm/model_executor/models/internvl.py
View file @
13370712
...
...
@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
...
...
vllm/model_executor/models/llama.py
View file @
13370712
...
...
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
...
...
@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
WeightsMapper
,
extract_layer_index
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
sampler
=
get_sampler
()
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
STEP
,
normalize
=
False
,
softmax
=
False
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
...
...
@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata
)
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
logits
=
self
.
compute_logits
(
hidden_states
,
None
)
return
self
.
_pooler
(
logits
,
pooling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
...
...
@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
class
LlamaEmbeddingModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
}
embedding_padding_modules
=
[]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
def
load_kv_cache_scales
(
self
,
quantization_param_path
:
str
)
->
None
:
self
.
model
.
load_kv_cache_scales
(
quantization_param_path
)
# LRUCacheWorkerLoRAManager instantiation requires model config.
@
property
def
config
(
self
):
return
self
.
model
.
config
vllm/model_executor/models/llava.py
View file @
13370712
...
...
@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_next.py
View file @
13370712
...
...
@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
...
...
@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
vision_feature_layer
=
config
.
vision_feature_layer
...
...
@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/llava_next_video.py
View file @
13370712
...
...
@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment