Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
051eaf6d
Unverified
Commit
051eaf6d
authored
Oct 19, 2024
by
Cyrus Leung
Committed by
GitHub
Oct 18, 2024
Browse files
[Model] Add user-configurable task for models that support both generation and embedding (#9424)
parent
7dbe738d
Changes
33
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
233 additions
and
71 deletions
+233
-71
tests/multimodal/test_processor_kwargs.py
tests/multimodal/test_processor_kwargs.py
+6
-1
tests/quantization/test_configs.py
tests/quantization/test_configs.py
+2
-1
tests/test_config.py
tests/test_config.py
+50
-7
tests/test_utils.py
tests/test_utils.py
+6
-6
tests/utils.py
tests/utils.py
+4
-4
vllm/config.py
vllm/config.py
+54
-23
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+44
-12
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+2
-1
vllm/utils.py
vllm/utils.py
+47
-3
vllm/worker/worker.py
vllm/worker/worker.py
+1
-4
No files found.
tests/multimodal/test_processor_kwargs.py
View file @
051eaf6d
...
...
@@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
expected_seq_count
=
DEFAULT_NUM_CROPS
if
num_crops
is
None
else
num_crops
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
task
=
"generate"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
mm_processor_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
})
...
...
@@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
def
test_max_tokens_with_sad_kwarg_overrides
(
mm_processor_kwargs
):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
task
=
"generate"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
mm_processor_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
})
...
...
@@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
### Test overrides for the mapper
@
pytest
.
mark
.
parametrize
(
"num_crops"
,
[
DEFAULT_NUM_CROPS
,
NUM_CROPS_OVERRIDE
])
def
test_default_mapper_with_process
e
r_kwargs
(
image_assets
,
num_crops
):
def
test_default_mapper_with_process
o
r_kwargs
(
image_assets
,
num_crops
):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
task
=
"generate"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
{
"num_crops"
:
num_crops
},
limit_mm_per_prompt
=
{
"image"
:
1
})
...
...
@@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
init_num_crops
,
inference_num_crops
)
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
task
=
"generate"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
init_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
})
...
...
@@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
task
=
"generate"
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
mm_processor_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
})
...
...
tests/quantization/test_configs.py
View file @
051eaf6d
...
...
@@ -57,7 +57,8 @@ def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
try
:
model_config
=
ModelConfig
(
model_path
,
model_path
,
task
=
"auto"
,
tokenizer
=
model_path
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
...
...
tests/test_config.py
View file @
051eaf6d
...
...
@@ -2,6 +2,42 @@ import pytest
from
vllm.config
import
ModelConfig
@pytest.mark.parametrize(
    ("model_id", "expected_task"),
    [
        ("facebook/opt-125m", "generate"),
        ("intfloat/e5-mistral-7b-instruct", "embedding"),
    ],
)
def test_auto_task(model_id, expected_task):
    """Check that ``task="auto"`` resolves to the expected task per model."""
    resolved = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    # The resolved task is exposed on the config as `task`.
    assert resolved.task == expected_task
@pytest.mark.parametrize(
    ("model_id", "bad_task"),
    [
        ("facebook/opt-125m", "embedding"),
        ("intfloat/e5-mistral-7b-instruct", "generate"),
    ],
)
def test_incorrect_task(model_id, bad_task):
    """Check that an explicitly requested, unsupported task is rejected."""
    expected_error = pytest.raises(
        ValueError,
        match=r"does not support the .* task",
    )

    # Constructing the config is what is expected to raise.
    with expected_error:
        ModelConfig(
            model=model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )
MODEL_IDS_EXPECTED
=
[
(
"Qwen/Qwen1.5-7B"
,
32768
),
(
"mistralai/Mistral-7B-v0.1"
,
4096
),
...
...
@@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
model_id
,
expected
=
model_id_expected
model_config
=
ModelConfig
(
model_id
,
model_id
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
...
...
@@ -32,7 +69,8 @@ def test_get_sliding_window():
# when use_sliding_window is False.
qwen2_model_config
=
ModelConfig
(
"Qwen/Qwen1.5-7B"
,
"Qwen/Qwen1.5-7B"
,
task
=
"auto"
,
tokenizer
=
"Qwen/Qwen1.5-7B"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
...
...
@@ -49,7 +87,8 @@ def test_get_sliding_window():
mistral_model_config
=
ModelConfig
(
"mistralai/Mistral-7B-v0.1"
,
"mistralai/Mistral-7B-v0.1"
,
task
=
"auto"
,
tokenizer
=
"mistralai/Mistral-7B-v0.1"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
...
...
@@ -70,7 +109,8 @@ def test_rope_customization():
llama_model_config
=
ModelConfig
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"meta-llama/Meta-Llama-3-8B-Instruct"
,
task
=
"auto"
,
tokenizer
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
...
...
@@ -82,7 +122,8 @@ def test_rope_customization():
llama_model_config
=
ModelConfig
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"meta-llama/Meta-Llama-3-8B-Instruct"
,
task
=
"auto"
,
tokenizer
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
...
...
@@ -98,7 +139,8 @@ def test_rope_customization():
longchat_model_config
=
ModelConfig
(
"lmsys/longchat-13b-16k"
,
"lmsys/longchat-13b-16k"
,
task
=
"auto"
,
tokenizer
=
"lmsys/longchat-13b-16k"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
...
...
@@ -112,7 +154,8 @@ def test_rope_customization():
longchat_model_config
=
ModelConfig
(
"lmsys/longchat-13b-16k"
,
"lmsys/longchat-13b-16k"
,
task
=
"auto"
,
tokenizer
=
"lmsys/longchat-13b-16k"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
...
...
tests/test_utils.py
View file @
051eaf6d
...
...
@@ -59,7 +59,7 @@ def test_deprecate_kwargs_always():
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'old_arg'"
):
dummy
(
old_arg
=
1
)
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
new_arg
=
1
)
...
...
@@ -69,10 +69,10 @@ def test_deprecate_kwargs_never():
def
dummy
(
*
,
old_arg
:
object
=
None
,
new_arg
:
object
=
None
):
pass
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
old_arg
=
1
)
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
new_arg
=
1
)
...
...
@@ -86,15 +86,15 @@ def test_deprecate_kwargs_dynamic():
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'old_arg'"
):
dummy
(
old_arg
=
1
)
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
new_arg
=
1
)
is_deprecated
=
False
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
old_arg
=
1
)
with
error_on_warning
():
with
error_on_warning
(
DeprecationWarning
):
dummy
(
new_arg
=
1
)
...
...
tests/utils.py
View file @
051eaf6d
...
...
@@ -8,7 +8,7 @@ import time
import
warnings
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Type
,
Union
import
openai
import
pytest
...
...
@@ -454,13 +454,13 @@ def multi_process_parallel(
@contextmanager
def error_on_warning(category: Type[Warning] = Warning):
    """
    Within the scope of this context manager, tests will fail if any warning
    of the given category is emitted.

    Args:
        category: The warning class to escalate to an exception. Defaults to
            ``Warning``, which escalates every warning.
    """
    # `catch_warnings` restores the previous filter state on exit, so the
    # escalation does not leak outside the `with` block.
    with warnings.catch_warnings():
        # Only escalate the requested category; a blanket
        # `simplefilter("error")` here would make `category` meaningless.
        warnings.filterwarnings("error", category=category)

        yield
...
...
vllm/config.py
View file @
051eaf6d
import
enum
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
from
transformers
import
PretrainedConfig
...
...
@@ -33,6 +33,9 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
Task
=
Literal
[
"generate"
,
"embedding"
]
TaskOption
=
Literal
[
"auto"
,
Task
]
class
ModelConfig
:
"""Configuration for the model.
...
...
@@ -41,6 +44,10 @@ class ModelConfig:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
...
...
@@ -108,6 +115,7 @@ class ModelConfig:
def
__init__
(
self
,
model
:
str
,
task
:
TaskOption
,
tokenizer
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
...
...
@@ -207,7 +215,11 @@ class ModelConfig:
self
.
override_neuron_config
=
override_neuron_config
if
is_neuron
(
)
else
None
self
.
_verify_embedding_mode
()
supported_tasks
,
task
=
self
.
_resolve_task
(
task
,
self
.
hf_config
)
self
.
supported_tasks
=
supported_tasks
self
.
task
:
Final
=
task
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
self
.
_verify_bnb_config
()
...
...
@@ -241,18 +253,41 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'."
)
self
.
tokenizer_mode
=
tokenizer_mode
def _resolve_task(
    self,
    task_option: TaskOption,
    hf_config: PretrainedConfig,
) -> Tuple[Set[Task], Task]:
    """Determine which tasks the model supports and which one to use.

    Args:
        task_option: The requested task, or "auto" to pick one
            automatically.
        hf_config: The HuggingFace config; its ``architectures`` attribute
            (defaulting to an empty list) is checked against the registry.

    Returns:
        A tuple of (the set of supported tasks, the selected task).

    Raises:
        ValueError: If ``task_option`` names a task the model does not
            support, or if "auto" is requested but no task is supported.
    """
    architectures = getattr(hf_config, "architectures", [])

    task_support: Dict[Task, bool] = {
        # NOTE: Listed from highest to lowest priority,
        # in case the model supports multiple of them
        "generate": ModelRegistry.is_text_generation_model(architectures),
        "embedding": ModelRegistry.is_embedding_model(architectures),
    }
    # Keep the list form as well: it preserves the priority order above,
    # which the set does not.
    supported_tasks_lst: List[Task] = [
        task for task, is_supported in task_support.items() if is_supported
    ]
    supported_tasks = set(supported_tasks_lst)

    if task_option == "auto":
        if not supported_tasks_lst:
            # Avoid leaking a bare StopIteration from next() below.
            raise ValueError(
                "Cannot resolve task='auto': this model does not support "
                f"any of the known tasks {list(task_support)} "
                f"(architectures: {architectures}).")

        # Highest-priority supported task wins.
        selected_task = supported_tasks_lst[0]

        if len(supported_tasks) > 1:
            logger.info(
                "This model supports multiple tasks: %s. "
                "Defaulting to '%s'.", supported_tasks, selected_task)
    else:
        if task_option not in supported_tasks:
            msg = (
                f"This model does not support the '{task_option}' task. "
                f"Supported tasks: {supported_tasks}")
            raise ValueError(msg)

        selected_task = task_option

    return supported_tasks, selected_task
def
_parse_quant_hf_config
(
self
):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
...
...
@@ -401,7 +436,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if
self
.
embedding
_mode
:
if
self
.
task
==
"
embedding
"
:
self
.
use_async_output_proc
=
False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
...
...
@@ -582,11 +617,6 @@ class ModelConfig:
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
)))
@
property
def
is_embedding_model
(
self
)
->
bool
:
"""Extract the embedding model flag."""
return
self
.
embedding_mode
@
property
def
is_multimodal_model
(
self
)
->
bool
:
return
self
.
multimodal_config
is
not
None
...
...
@@ -943,6 +973,7 @@ class SchedulerConfig:
"""Scheduler configuration.
Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
...
...
@@ -957,7 +988,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
...
...
@@ -972,13 +1002,13 @@ class SchedulerConfig:
"""
def
__init__
(
self
,
task
:
Task
,
max_num_batched_tokens
:
Optional
[
int
],
max_num_seqs
:
int
,
max_model_len
:
int
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
bool
=
False
,
is_multimodal_model
:
bool
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
,
...
...
@@ -1002,7 +1032,7 @@ class SchedulerConfig:
# for higher throughput.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
embedding
_mode
:
if
task
==
"
embedding
"
:
# For embedding, choose specific value for higher throughput
max_num_batched_tokens
=
max
(
max_num_batched_tokens
,
...
...
@@ -1022,12 +1052,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d."
,
self
.
max_num_batched_tokens
)
self
.
task
:
Final
=
task
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
...
...
@@ -1239,6 +1269,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min
=
0
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
task
=
target_model_config
.
task
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
...
...
vllm/core/scheduler.py
View file @
051eaf6d
...
...
@@ -313,7 +313,7 @@ class Scheduler:
self
.
lora_config
=
lora_config
version
=
"selfattn"
if
(
self
.
scheduler_config
.
embedding
_mode
if
(
self
.
scheduler_config
.
task
==
"
embedding
"
or
self
.
cache_config
.
is_attention_free
):
version
=
"placeholder"
...
...
vllm/engine/arg_utils.py
View file @
051eaf6d
...
...
@@ -3,7 +3,7 @@ import dataclasses
import
json
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
,
cast
)
Tuple
,
Type
,
Union
,
cast
,
get_args
)
import
torch
...
...
@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
SpeculativeConfig
,
TaskOption
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
...
...
@@ -84,6 +84,7 @@ class EngineArgs:
model
:
str
=
'facebook/opt-125m'
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
tokenizer
:
Optional
[
str
]
=
None
task
:
TaskOption
=
"auto"
skip_tokenizer_init
:
bool
=
False
tokenizer_mode
:
str
=
'auto'
trust_remote_code
:
bool
=
False
...
...
@@ -198,6 +199,15 @@ class EngineArgs:
type
=
str
,
default
=
EngineArgs
.
model
,
help
=
'Name or path of the huggingface model to use.'
)
parser
.
add_argument
(
'--task'
,
default
=
EngineArgs
.
task
,
choices
=
get_args
(
TaskOption
),
help
=
'The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.'
)
parser
.
add_argument
(
'--tokenizer'
,
type
=
nullable_str
,
...
...
@@ -838,6 +848,7 @@ class EngineArgs:
def
create_model_config
(
self
)
->
ModelConfig
:
return
ModelConfig
(
model
=
self
.
model
,
task
=
self
.
task
,
# We know this is not None because we set it in __post_init__
tokenizer
=
cast
(
str
,
self
.
tokenizer
),
tokenizer_mode
=
self
.
tokenizer_mode
,
...
...
@@ -1026,13 +1037,13 @@ class EngineArgs:
" please file an issue with detailed information."
)
scheduler_config
=
SchedulerConfig
(
task
=
model_config
.
task
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_seqs
=
self
.
max_num_seqs
,
max_model_len
=
model_config
.
max_model_len
,
num_lookahead_slots
=
num_lookahead_slots
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
...
...
vllm/engine/llm_engine.py
View file @
051eaf6d
...
...
@@ -344,7 +344,7 @@ class LLMEngine:
observability_config
=
self
.
observability_config
,
)
if
not
self
.
model_config
.
embedding
_mode
:
if
self
.
model_config
.
task
!=
"
embedding
"
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
...
...
@@ -1116,7 +1116,7 @@ class LLMEngine:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
if
self
.
model_config
.
embedding
_mode
:
if
self
.
model_config
.
task
==
"
embedding
"
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
...
...
@@ -1855,9 +1855,6 @@ class LLMEngine:
def
is_encoder_decoder_model
(
self
):
return
self
.
input_preprocessor
.
is_encoder_decoder_model
()
def
is_embedding_model
(
self
):
return
self
.
model_config
.
is_embedding_model
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]):
if
self
.
model_config
.
is_multimodal_model
:
...
...
vllm/entrypoints/llm.py
View file @
051eaf6d
...
...
@@ -8,7 +8,7 @@ from tqdm import tqdm
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
,
TaskOption
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_hf_chat_template
,
...
...
@@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_kwargs
,
is_list_of
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
logger
=
init_logger
(
__name__
)
...
...
@@ -108,6 +108,12 @@ class LLM:
DEPRECATE_LEGACY
:
ClassVar
[
bool
]
=
False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS
:
ClassVar
[
bool
]
=
True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@
classmethod
@
contextmanager
def
deprecate_legacy_api
(
cls
):
...
...
@@ -117,6 +123,13 @@ class LLM:
cls
.
DEPRECATE_LEGACY
=
False
@
deprecate_args
(
start_index
=
2
,
# Ignore self and model
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_INIT_POSARGS
,
additional_message
=
(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."
),
)
def
__init__
(
self
,
model
:
str
,
...
...
@@ -139,6 +152,8 @@ class LLM:
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
task
:
TaskOption
=
"auto"
,
**
kwargs
,
)
->
None
:
'''
...
...
@@ -153,6 +168,7 @@ class LLM:
engine_args
=
EngineArgs
(
model
=
model
,
task
=
task
,
tokenizer
=
tokenizer
,
tokenizer_mode
=
tokenizer_mode
,
skip_tokenizer_init
=
skip_tokenizer_init
,
...
...
@@ -316,10 +332,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if
self
.
llm_engine
.
model_config
.
embedding_mode
:
raise
ValueError
(
task
=
self
.
llm_engine
.
model_config
.
task
if
task
!=
"generate"
:
messages
=
[
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration)."
)
"models (XForCausalLM, XForConditionalGeneration)."
,
]
supported_tasks
=
self
.
llm_engine
.
model_config
.
supported_tasks
if
"generate"
in
supported_tasks
:
messages
.
append
(
"Your model supports the 'generate' task, but is "
f
"currently initialized for the '
{
task
}
' task. Please "
"initialize the model using `--task generate`."
)
raise
ValueError
(
" "
.
join
(
messages
))
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
...
...
@@ -692,10 +719,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if
not
self
.
llm_engine
.
model_config
.
embedding_mode
:
raise
ValueError
(
"LLM.encode() is only supported for embedding models (XModel)."
)
task
=
self
.
llm_engine
.
model_config
.
task
if
task
!=
"embedding"
:
messages
=
[
"LLM.encode() is only supported for embedding models."
]
supported_tasks
=
self
.
llm_engine
.
model_config
.
supported_tasks
if
"embedding"
in
supported_tasks
:
messages
.
append
(
"Your model supports the 'embedding' task, but is "
f
"currently initialized for the '
{
task
}
' task. Please "
"initialize the model using `--task embedding`."
)
raise
ValueError
(
" "
.
join
(
messages
))
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
...
...
@@ -905,6 +940,3 @@ class LLM:
def
_is_encoder_decoder_model
(
self
):
return
self
.
llm_engine
.
is_encoder_decoder_model
()
def
_is_embedding_model
(
self
):
return
self
.
llm_engine
.
is_embedding_model
()
vllm/entrypoints/openai/serving_embedding.py
View file @
051eaf6d
...
...
@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
self
.
_enabled
=
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
self
.
_enabled
=
self
.
_check_embedding_mode
(
model_config
.
task
==
"embedding"
)
async
def
create_embedding
(
self
,
...
...
vllm/utils.py
View file @
051eaf6d
...
...
@@ -1034,10 +1034,54 @@ def identity(value: T) -> T:
F = TypeVar('F', bound=Callable[..., Any])


def deprecate_args(
    start_index: int,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when arguments are passed positionally.

    Args:
        start_index: Index of the first positional parameter (counting from
            zero over the function's positional parameters) that is
            deprecated when passed positionally.
        is_deprecated: Either a flag, or a zero-argument callable that is
            evaluated on every call, deciding whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that wraps a function so that it emits a
        ``DeprecationWarning`` naming the deprecated positional arguments,
        and then calls the function unchanged.
    """
    if not callable(is_deprecated):
        # Normalize the flag to a callable so `inner` has a single code path.
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        params = inspect.signature(fn).parameters
        pos_types = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        # Names of the parameters that can be filled positionally, in order.
        pos_kws = [
            kw for kw, param in params.items() if param.kind in pos_types
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Parameters at or after `start_index` that this call
                # actually filled positionally.
                deprecated_args = pos_kws[start_index:len(args)]
                if deprecated_args:
                    msg = (
                        f"The positional arguments {deprecated_args} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # `warn` runs inside `inner`, so level 2 is the
                        # code that called the decorated function.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
def
deprecate_kwargs
(
*
kws
:
str
,
is_deprecated
:
Union
[
bool
,
Callable
[[],
bool
]]
=
True
,
additional_message
:
Optional
[
str
]
=
None
)
->
Callable
[[
F
],
F
]:
additional_message
:
Optional
[
str
]
=
None
,
)
->
Callable
[[
F
],
F
]:
deprecated_kws
=
set
(
kws
)
if
not
callable
(
is_deprecated
):
...
...
vllm/worker/worker.py
View file @
051eaf6d
...
...
@@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
if
model_runner_cls
is
not
None
:
ModelRunnerClass
=
model_runner_cls
elif
self
.
_is_
embedding
_model
()
:
elif
model_config
.
task
==
"
embedding
"
:
ModelRunnerClass
=
EmbeddingModelRunner
elif
self
.
_is_encoder_decoder_model
():
ModelRunnerClass
=
EncoderDecoderModelRunner
...
...
@@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase):
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
_is_embedding_model
(
self
):
return
self
.
model_config
.
is_embedding_model
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
# torch.distributed.all_reduce does not free the input tensor until
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment