Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13370712
Unverified
Commit
13370712
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7e4bbda5
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
262 additions
and
227 deletions
+262
-227
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+14
-1
tests/conftest.py
tests/conftest.py
+0
-1
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+5
-0
tests/models/test_registry.py
tests/models/test_registry.py
+14
-17
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+40
-5
tests/test_config.py
tests/test_config.py
+1
-2
vllm/config.py
vllm/config.py
+25
-0
vllm/inputs/registry.py
vllm/inputs/registry.py
+8
-8
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+1
-3
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+14
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+14
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+98
-0
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-2
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-56
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-95
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+3
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+5
-21
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+3
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
13370712
...
@@ -334,7 +334,6 @@ steps:
...
@@ -334,7 +334,6 @@ steps:
commands
:
commands
:
-
pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
-
pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
-
pytest -v -s models/embedding/language -m core_model
-
pytest -v -s models/embedding/language -m core_model
-
pytest -v -s models/embedding/vision_language -m core_model
-
label
:
Language Models Test (Extended)
# 50min
-
label
:
Language Models Test (Extended)
# 50min
optional
:
true
optional
:
true
...
@@ -346,7 +345,6 @@ steps:
...
@@ -346,7 +345,6 @@ steps:
commands
:
commands
:
-
pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
-
pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
-
pytest -v -s models/embedding/language -m 'not core_model'
-
pytest -v -s models/embedding/language -m 'not core_model'
-
pytest -v -s models/embedding/vision_language -m 'not core_model'
-
label
:
Multi-Modal Models Test (Standard)
# 26min
-
label
:
Multi-Modal Models Test (Standard)
# 26min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
...
@@ -359,6 +357,7 @@ steps:
...
@@ -359,6 +357,7 @@ steps:
commands
:
commands
:
-
pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-
pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
-
pytest -v -s models/embedding/vision_language -m core_model
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
...
@@ -376,6 +375,7 @@ steps:
...
@@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307
# https://github.com/huggingface/transformers/issues/34307
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
-
pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
-
pytest -v -s models/embedding/vision_language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
-
pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
...
...
docs/source/models/supported_models.rst
View file @
13370712
...
@@ -357,7 +357,7 @@ Text Embedding
...
@@ -357,7 +357,7 @@ Text Embedding
- ✅︎
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`
(see note)
, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
...
@@ -378,6 +378,10 @@ Text Embedding
...
@@ -378,6 +378,10 @@ Text Embedding
.. tip::
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
.. note::
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
...
@@ -397,12 +401,21 @@ Reward Modeling
...
@@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
- :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel`
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎
- ✅︎
- ✅︎
- ✅︎
.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note::
.. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
...
...
tests/conftest.py
View file @
13370712
...
@@ -263,7 +263,6 @@ class HfRunner:
...
@@ -263,7 +263,6 @@ class HfRunner:
dtype
:
str
=
"half"
,
dtype
:
str
=
"half"
,
*
,
*
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_embedding_model
:
bool
=
False
,
is_sentence_transformer
:
bool
=
False
,
is_sentence_transformer
:
bool
=
False
,
is_cross_encoder
:
bool
=
False
,
is_cross_encoder
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
...
...
tests/models/embedding/language/test_embedding.py
View file @
13370712
...
@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
...
@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
"""
import
pytest
import
pytest
from
vllm.config
import
PoolerConfig
from
..utils
import
check_embeddings_close
from
..utils
import
check_embeddings_close
...
@@ -33,6 +35,9 @@ def test_models(
...
@@ -33,6 +35,9 @@ def test_models(
dtype
:
str
,
dtype
:
str
,
)
->
None
:
)
->
None
:
vllm_extra_kwargs
=
{}
vllm_extra_kwargs
=
{}
if
model
==
"ssmits/Qwen2-7B-Instruct-embed-base"
:
vllm_extra_kwargs
[
"override_pooler_config"
]
=
\
PoolerConfig
(
pooling_type
=
"MEAN"
)
if
model
==
"Alibaba-NLP/gte-Qwen2-7B-instruct"
:
if
model
==
"Alibaba-NLP/gte-Qwen2-7B-instruct"
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
{
"is_causal"
:
False
}
vllm_extra_kwargs
[
"hf_overrides"
]
=
{
"is_causal"
:
False
}
...
...
tests/models/test_registry.py
View file @
13370712
...
@@ -6,11 +6,8 @@ import torch.cuda
...
@@ -6,11 +6,8 @@ import torch.cuda
from
vllm.model_executor.models
import
(
is_embedding_model
,
from
vllm.model_executor.models
import
(
is_embedding_model
,
is_text_generation_model
,
is_text_generation_model
,
supports_multimodal
)
supports_multimodal
)
# yapf conflicts with isort for this block
from
vllm.model_executor.models.adapters
import
as_embedding_model
# yapf: disable
from
vllm.model_executor.models.registry
import
(
_MULTIMODAL_MODELS
,
from
vllm.model_executor.models.registry
import
(
_CROSS_ENCODER_MODELS
,
_EMBEDDING_MODELS
,
_MULTIMODAL_MODELS
,
_SPECULATIVE_DECODING_MODELS
,
_SPECULATIVE_DECODING_MODELS
,
_TEXT_GENERATION_MODELS
,
_TEXT_GENERATION_MODELS
,
ModelRegistry
)
ModelRegistry
)
...
@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
...
@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
if
model_arch
in
_SPECULATIVE_DECODING_MODELS
:
if
model_arch
in
_SPECULATIVE_DECODING_MODELS
:
pass
# Ignore these models which do not have a unified format
return
# Ignore these models which do not have a unified format
else
:
assert
is_text_generation_model
(
model_cls
)
is
(
if
(
model_arch
in
_TEXT_GENERATION_MODELS
model_arch
in
_
TEXT_GENERATION
_MODELS
or
model_arch
in
_
MULTIMODAL
_MODELS
):
or
model_arch
in
_MULTIMODAL_MODELS
)
assert
is_text_generation_model
(
model_cls
)
embedding
_
model
s
=
{
**
_EMBEDDING_MODELS
,
**
_CROSS_ENCODER_MODELS
}
# All vLLM models should be convertible to an
embedding
model
assert
i
s_embedding_model
(
model_cls
)
is
(
model_arch
embed_model
=
a
s_embedding_model
(
model_cls
)
in
embedding_model
s
)
assert
is_
embedding_model
(
embed_model
)
assert
supports_multimodal
(
model_cls
)
is
(
model_arch
if
model_arch
in
_MULTIMODAL_MODELS
:
in
_MULTIMODAL_MODELS
)
assert
supports_multimodal
(
model_cls
)
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
13370712
from
typing
import
List
,
Optional
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.models.gemma2
import
Gemma2EmbeddingModel
from
vllm.config
import
VllmConfig
from
vllm.sequence
import
IntermediateTensors
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
class
MyGemma2Embedding
(
Gemma2EmbeddingModel
):
class
MyGemma2Embedding
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
,
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
...
@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
s
uper
().
forward
(
hidden_states
=
s
elf
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
...
@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
...
@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
# Return all-zero embeddings
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
return
torch
.
zeros_like
(
hidden_states
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
return
self
.
model
.
load_weights
(
weights
)
tests/test_config.py
View file @
13370712
...
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
...
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"bad_task"
),
[
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"bad_task"
),
[
(
"facebook/opt-125m"
,
"embedding"
),
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"generate"
),
(
"intfloat/e5-mistral-7b-instruct"
,
"generate"
),
])
])
def
test_incorrect_task
(
model_id
,
bad_task
):
def
test_incorrect_task
(
model_id
,
bad_task
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"does not support the .* task"
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"does not support the .* task"
):
...
...
vllm/config.py
View file @
13370712
...
@@ -370,6 +370,31 @@ class ModelConfig:
...
@@ -370,6 +370,31 @@ class ModelConfig:
selected_task
=
next
(
iter
(
supported_tasks_lst
))
selected_task
=
next
(
iter
(
supported_tasks_lst
))
if
len
(
supported_tasks
)
>
1
:
if
len
(
supported_tasks
)
>
1
:
suffix_to_preferred_task
:
List
[
Tuple
[
str
,
_Task
]]
=
[
# Hardcode the models that are exceptions
(
"AquilaModel"
,
"generate"
),
(
"ChatGLMModel"
,
"generate"
),
# Other models follow this pattern
(
"ForCausalLM"
,
"generate"
),
(
"ForConditionalGeneration"
,
"generate"
),
(
"ChatModel"
,
"generate"
),
(
"LMHeadModel"
,
"generate"
),
(
"EmbeddingModel"
,
"embedding"
),
(
"RewardModel"
,
"embedding"
),
(
"ForSequenceClassification"
,
"embedding"
),
]
info
,
arch
=
ModelRegistry
.
inspect_model_cls
(
architectures
)
for
suffix
,
pref_task
in
suffix_to_preferred_task
:
if
arch
.
endswith
(
suffix
)
and
pref_task
in
supported_tasks
:
selected_task
=
pref_task
break
else
:
if
(
arch
.
endswith
(
"Model"
)
and
info
.
architecture
.
endswith
(
"ForCausalLM"
)
and
"embedding"
in
supported_tasks
):
selected_task
=
"embedding"
logger
.
info
(
logger
.
info
(
"This model supports multiple tasks: %s. "
"This model supports multiple tasks: %s. "
"Defaulting to '%s'."
,
supported_tasks
,
selected_task
)
"Defaulting to '%s'."
,
supported_tasks
,
selected_task
)
...
...
vllm/inputs/registry.py
View file @
13370712
...
@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
...
@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
print_warning_once
,
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
print_warning_once
,
resolve_mm_processor_kwargs
)
from
.data
import
ProcessorInputs
,
SingletonInputs
from
.data
import
ProcessorInputs
,
SingletonInputs
from
.parse
import
is_encoder_decoder_inputs
from
.parse
import
is_encoder_decoder_inputs
...
@@ -136,12 +136,12 @@ class InputRegistry:
...
@@ -136,12 +136,12 @@ class InputRegistry:
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
_dummy_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
self
.
_dummy_factories_by_model_type
=
\
DummyDataFactory
]
=
{}
ClassRegistry
[
nn
.
Module
,
DummyDataFactory
]
()
self
.
_dummy_encoder_factories_by_model_type
:
Dict
[
self
.
_dummy_encoder_factories_by_model_type
=
\
Type
[
nn
.
Module
]
,
DummyDataFactory
]
=
{}
ClassRegistry
[
nn
.
Module
,
DummyDataFactory
]
()
self
.
_input_processors_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
self
.
_input_processors_by_model_type
=
\
InputProcessor
]
=
{}
ClassRegistry
[
nn
.
Module
,
InputProcessor
]
()
def
_default_dummy_data_factory
(
def
_default_dummy_data_factory
(
self
,
self
,
...
...
vllm/model_executor/layers/pooler.py
View file @
13370712
...
@@ -60,9 +60,7 @@ class Pooler(nn.Module):
...
@@ -60,9 +60,7 @@ class Pooler(nn.Module):
softmax
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
List
[
int
]]
=
None
,
returned_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
Optional
[
"Pooler"
]:
)
->
"Pooler"
:
if
pooler_config
is
None
:
return
None
return
cls
(
return
cls
(
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
,
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
13370712
...
@@ -9,6 +9,7 @@ import itertools
...
@@ -9,6 +9,7 @@ import itertools
import
json
import
json
import
math
import
math
import
os
import
os
import
warnings
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
...
@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
...
@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_initialize_model
(
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
def
_initialize_model
(
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
model_class
,
_
=
get_model_architecture
(
model_config
)
model_class
,
_
=
get_model_architecture
(
model_config
,
architectures
=
architectures
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
# new-style model class
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
"for the design and update the model class accordingly."
)
logger
.
warning
(
msg
)
warnings
.
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
logger
.
warning
(
logger
.
warning
(
"Trying to guess the arguments for old-style model class %s"
,
"Trying to guess the arguments for old-style model class %s"
,
model_class
,
model_class
,
...
@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
loaded_weights
=
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
self
.
_get_all_weights
(
model_config
,
model
))
# We only enable strict check for non-quanti
i
zed models
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
# that have loaded weights tracking currently.
if
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
if
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
weights_not_loaded
=
weights_to_load
-
loaded_weights
...
...
vllm/model_executor/model_loader/utils.py
View file @
13370712
"""Utilities for selecting and loading models."""
"""Utilities for selecting and loading models."""
import
contextlib
import
contextlib
from
typing
import
Tuple
,
Type
from
typing
import
Optional
,
Tuple
,
Type
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.adapters
import
as_embedding_model
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
,
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
*
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
if
architectures
is
None
:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
mixtral_supported
=
[
...
@@ -32,7 +38,11 @@ def get_model_architecture(
...
@@ -32,7 +38,11 @@ def get_model_architecture(
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
if
model_config
.
task
==
"embedding"
:
model_cls
=
as_embedding_model
(
model_cls
)
return
model_cls
,
arch
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
...
...
vllm/model_executor/models/adapters.py
0 → 100644
View file @
13370712
from
collections.abc
import
Iterable
from
typing
import
Any
,
TypeVar
import
torch
import
torch.nn
as
nn
from
.interfaces_base
import
VllmModelForEmbedding
,
is_embedding_model
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
def
as_embedding_model
(
cls
:
_T
)
->
_T
:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if
is_embedding_model
(
cls
):
return
cls
# Lazy import
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
Pooler
,
PoolerOutput
,
PoolingType
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForEmbedding
(
cls
,
VllmModelForEmbedding
):
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# These are not used in embedding models
for
attr
in
(
"lm_head"
,
"logits_processor"
):
if
hasattr
(
self
,
attr
):
delattr
(
self
,
attr
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"_pooler"
,
None
):
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
,
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
# TODO: Support uninitialized params tracking
# We have deleted this attribute, so don't load it
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if
hasattr
(
self
,
"model"
)
and
hasattr
(
self
.
model
,
"load_weights"
):
# Whether only `self.model` contains parameters
model_is_only_param
=
all
(
name
==
"model"
or
next
(
child
.
parameters
(),
None
)
is
None
for
name
,
child
in
self
.
named_children
())
if
model_is_only_param
:
mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
mapper
.
apply
(
weights
)
self
.
model
.
load_weights
(
weights
)
return
# For most other models
if
hasattr
(
cls
,
"load_weights"
):
cls
.
load_weights
(
self
,
weights
)
# type: ignore
# Fallback
else
:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
ModelForEmbedding
.
__name__
=
cls
.
__name__
\
.
removesuffix
(
"ForCausalLM"
)
\
.
removesuffix
(
"ForConditionalGeneration"
)
\
.
removesuffix
(
"ChatModel"
)
\
.
removesuffix
(
"LMHeadModel"
)
+
"ForEmbedding"
return
ModelForEmbedding
# type: ignore
vllm/model_executor/models/blip2.py
View file @
13370712
...
@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/gemma2.py
View file @
13370712
...
@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
if
self
.
config
.
tie_word_embeddings
else
None
),
)
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
class
Gemma2EmbeddingModel
(
nn
.
Module
,
SupportsPP
):
"""
A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
vllm/model_executor/models/internvl.py
View file @
13370712
...
@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
...
...
vllm/model_executor/models/llama.py
View file @
13370712
...
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
get_compressed_tensors_cache_scale
)
...
@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
WeightsMapper
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
STEP
,
normalize
=
False
,
softmax
=
False
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
...
@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
logits
=
self
.
compute_logits
(
hidden_states
,
None
)
return
self
.
_pooler
(
logits
,
pooling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
SamplerOutput
]:
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
...
@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name
=
name
.
replace
(
item
,
mapping
[
item
])
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
return
name
,
loaded_weight
class
LlamaEmbeddingModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
}
embedding_padding_modules
=
[]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
))
self
.
model
.
load_weights
(
weights
)
def
load_kv_cache_scales
(
self
,
quantization_param_path
:
str
)
->
None
:
self
.
model
.
load_kv_cache_scales
(
quantization_param_path
)
# LRUCacheWorkerLoRAManager instantiation requires model config.
@
property
def
config
(
self
):
return
self
.
model
.
config
vllm/model_executor/models/llava.py
View file @
13370712
...
@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
projector_hidden_act
=
config
.
projector_hidden_act
)
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_next.py
View file @
13370712
...
@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
...
@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
)
InputContext
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
...
@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
vision_feature_layer
=
config
.
vision_feature_layer
vision_feature_layer
=
config
.
vision_feature_layer
...
@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act
=
config
.
projector_hidden_act
)
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
# The same model class supports both language generation and embedding
)
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/llava_next_video.py
View file @
13370712
...
@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment