Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2f058e7
Unverified
Commit
d2f058e7
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f877a7d1
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
153 additions
and
110 deletions
+153
-110
examples/offline_inference_embedding.py
examples/offline_inference_embedding.py
+1
-1
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+3
-3
tests/models/test_registry.py
tests/models/test_registry.py
+2
-2
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+2
-2
vllm/__init__.py
vllm/__init__.py
+27
-4
vllm/config.py
vllm/config.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-12
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-4
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+7
-7
vllm/engine/protocol.py
vllm/engine/protocol.py
+2
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+15
-15
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+6
-6
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+5
-5
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+5
-6
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+3
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+2
-2
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+7
-8
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+8
-8
vllm/outputs.py
vllm/outputs.py
+39
-16
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+2
-2
No files found.
examples/offline_inference_embedding.py
View file @
d2f058e7
...
@@ -10,7 +10,7 @@ prompts = [
...
@@ -10,7 +10,7 @@ prompts = [
# Create an LLM.
# Create an LLM.
model
=
LLM
(
model
=
"intfloat/e5-mistral-7b-instruct"
,
enforce_eager
=
True
)
model
=
LLM
(
model
=
"intfloat/e5-mistral-7b-instruct"
,
enforce_eager
=
True
)
# Generate embedding. The output is a list of
Embedd
ingRequestOutputs.
# Generate embedding. The output is a list of
Pool
ingRequestOutputs.
outputs
=
model
.
encode
(
prompts
)
outputs
=
model
.
encode
(
prompts
)
# Print the outputs.
# Print the outputs.
for
output
in
outputs
:
for
output
in
outputs
:
...
...
tests/entrypoints/llm/test_encode.py
View file @
d2f058e7
...
@@ -3,7 +3,7 @@ from typing import List
...
@@ -3,7 +3,7 @@ from typing import List
import
pytest
import
pytest
from
vllm
import
LLM
,
Embedd
ingRequestOutput
,
PoolingParams
from
vllm
import
LLM
,
PoolingParams
,
Pool
ingRequestOutput
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
...
@@ -43,8 +43,8 @@ def llm():
...
@@ -43,8 +43,8 @@ def llm():
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
Embedd
ingRequestOutput
],
def
assert_outputs_equal
(
o1
:
List
[
Pool
ingRequestOutput
],
o2
:
List
[
Embedd
ingRequestOutput
]):
o2
:
List
[
Pool
ingRequestOutput
]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
...
...
tests/models/test_registry.py
View file @
d2f058e7
...
@@ -3,7 +3,7 @@ import warnings
...
@@ -3,7 +3,7 @@ import warnings
import
pytest
import
pytest
import
torch.cuda
import
torch.cuda
from
vllm.model_executor.models
import
(
is_
embedd
ing_model
,
from
vllm.model_executor.models
import
(
is_
pool
ing_model
,
is_text_generation_model
,
is_text_generation_model
,
supports_multimodal
)
supports_multimodal
)
from
vllm.model_executor.models.adapters
import
as_embedding_model
from
vllm.model_executor.models.adapters
import
as_embedding_model
...
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
...
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
# All vLLM models should be convertible to an embedding model
# All vLLM models should be convertible to an embedding model
embed_model
=
as_embedding_model
(
model_cls
)
embed_model
=
as_embedding_model
(
model_cls
)
assert
is_
embedd
ing_model
(
embed_model
)
assert
is_
pool
ing_model
(
embed_model
)
if
model_arch
in
_MULTIMODAL_MODELS
:
if
model_arch
in
_MULTIMODAL_MODELS
:
assert
supports_multimodal
(
model_cls
)
assert
supports_multimodal
(
model_cls
)
...
...
tests/worker/test_model_input.py
View file @
d2f058e7
...
@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
...
@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.worker.embedding_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.multi_step_model_runner
import
StatefulModelInput
from
vllm.worker.multi_step_model_runner
import
StatefulModelInput
from
vllm.worker.pooling_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
class
MockAttentionBackend
(
AttentionBackend
):
class
MockAttentionBackend
(
AttentionBackend
):
...
...
vllm/__init__.py
View file @
d2f058e7
...
@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
...
@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptType
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptType
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
Embedd
ingOutput
,
from
vllm.outputs
import
(
CompletionOutput
,
Pool
ingOutput
,
Embedd
ingRequestOutput
,
RequestOutput
)
Pool
ingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -25,8 +25,8 @@ __all__ = [
...
@@ -25,8 +25,8 @@ __all__ = [
"SamplingParams"
,
"SamplingParams"
,
"RequestOutput"
,
"RequestOutput"
,
"CompletionOutput"
,
"CompletionOutput"
,
"
Embedd
ingOutput"
,
"
Pool
ingOutput"
,
"
Embedd
ingRequestOutput"
,
"
Pool
ingRequestOutput"
,
"LLMEngine"
,
"LLMEngine"
,
"EngineArgs"
,
"EngineArgs"
,
"AsyncLLMEngine"
,
"AsyncLLMEngine"
,
...
@@ -34,3 +34,26 @@ __all__ = [
...
@@ -34,3 +34,26 @@ __all__ = [
"initialize_ray_cluster"
,
"initialize_ray_cluster"
,
"PoolingParams"
,
"PoolingParams"
,
]
]
def
__getattr__
(
name
:
str
):
import
warnings
if
name
==
"EmbeddingOutput"
:
msg
=
(
"EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingOutput
if
name
==
"EmbeddingRequestOutput"
:
msg
=
(
"EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingRequestOutput
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/config.py
View file @
d2f058e7
...
@@ -359,7 +359,7 @@ class ModelConfig:
...
@@ -359,7 +359,7 @@ class ModelConfig:
# NOTE: Listed from highest to lowest priority,
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
# in case the model supports multiple of them
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"embedding"
:
ModelRegistry
.
is_
embedd
ing_model
(
architectures
),
"embedding"
:
ModelRegistry
.
is_
pool
ing_model
(
architectures
),
}
}
supported_tasks_lst
:
List
[
_Task
]
=
[
supported_tasks_lst
:
List
[
_Task
]
=
[
task
for
task
,
is_supported
in
task_support
.
items
()
if
is_supported
task
for
task
,
is_supported
in
task_support
.
items
()
if
is_supported
...
...
vllm/engine/async_llm_engine.py
View file @
d2f058e7
...
@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
...
@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
...
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
class
AsyncStream
:
class
AsyncStream
:
"""A stream of RequestOutputs or
Embedd
ingRequestOutputs for a request
"""A stream of RequestOutputs or
Pool
ingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
that can be iterated over asynchronously via an async generator."""
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
])
->
None
:
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
])
->
None
:
...
@@ -83,7 +83,7 @@ class AsyncStream:
...
@@ -83,7 +83,7 @@ class AsyncStream:
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Embedd
ingRequestOutput
,
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Pool
ingRequestOutput
,
Exception
])
->
None
:
Exception
])
->
None
:
if
not
self
.
_finished
:
if
not
self
.
_finished
:
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
...
@@ -103,7 +103,7 @@ class AsyncStream:
...
@@ -103,7 +103,7 @@ class AsyncStream:
async
def
generator
(
async
def
generator
(
self
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
try
:
try
:
while
True
:
while
True
:
result
=
await
self
.
_queue
.
get
()
result
=
await
self
.
_queue
.
get
()
...
@@ -154,7 +154,7 @@ class RequestTracker:
...
@@ -154,7 +154,7 @@ class RequestTracker:
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
Union
[
RequestOutput
,
request_output
:
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
Pool
ingRequestOutput
],
*
,
*
,
verbose
:
bool
=
False
)
->
None
:
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
"""Process a request output from the engine."""
...
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
async
def
step_async
(
async
def
step_async
(
self
,
virtual_engine
:
int
self
,
virtual_engine
:
int
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
The workers are ran asynchronously if possible.
...
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]]:
RequestOutput
,
Pool
ingRequestOutput
],
None
]]:
...
...
@
overload
@
overload
...
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]]:
RequestOutput
,
Pool
ingRequestOutput
],
None
]]:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
...
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
priority
:
int
=
0
,
priority
:
int
=
0
,
*
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
if
inputs
is
not
None
:
if
inputs
is
not
None
:
prompt
=
inputs
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
assert
prompt
is
not
None
and
params
is
not
None
...
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
Generate outputs for a request. This method is a coroutine. It adds the
...
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
Only applicable with priority scheduling.
Only applicable with priority scheduling.
Yields:
Yields:
The output `
Embedd
ingRequestOutput` objects from the LLMEngine
The output `
Pool
ingRequestOutput` objects from the LLMEngine
for the request.
for the request.
Details:
Details:
...
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
priority
,
priority
=
priority
,
):
):
yield
LLMEngine
.
validate_output
(
output
,
Embedd
ingRequestOutput
)
yield
LLMEngine
.
validate_output
(
output
,
Pool
ingRequestOutput
)
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
"""Abort a request.
...
...
vllm/engine/llm_engine.py
View file @
d2f058e7
...
@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
...
@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor
)
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
(
Embedd
ingRequestOutput
,
RequestOutput
,
from
vllm.outputs
import
(
Pool
ingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
...
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
Embedd
ingRequestOutput
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
Pool
ingRequestOutput
)
@
dataclass
@
dataclass
...
@@ -112,7 +112,7 @@ class SchedulerContext:
...
@@ -112,7 +112,7 @@ class SchedulerContext:
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]
=
[]
Pool
ingRequestOutput
]]
=
[]
self
.
seq_group_metadata_list
:
Optional
[
self
.
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
...
@@ -1314,7 +1314,7 @@ class LLMEngine:
...
@@ -1314,7 +1314,7 @@ class LLMEngine:
else
:
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
.. figure:: https://i.imgur.com/sv2HssD.png
...
...
vllm/engine/multiprocessing/client.py
View file @
d2f058e7
...
@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
...
@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
...
...
@
overload
@
overload
...
@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
...
@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
priority
:
int
=
0
,
priority
:
int
=
0
,
*
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
Generate outputs for a request. This method is a coroutine. It adds the
...
@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: OpenTelemetry trace headers.
trace_headers: OpenTelemetry trace headers.
Yields:
Yields:
The output `
Embedd
ingRequestOutput` objects from the LLMEngine
The output `
Pool
ingRequestOutput` objects from the LLMEngine
for the request.
for the request.
"""
"""
if
inputs
is
not
None
:
if
inputs
is
not
None
:
...
@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
and
request_id
is
not
None
)
and
request_id
is
not
None
)
return
cast
(
return
cast
(
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
],
AsyncGenerator
[
Pool
ingRequestOutput
,
None
],
self
.
_process_request
(
prompt
,
self
.
_process_request
(
prompt
,
pooling_params
,
pooling_params
,
request_id
,
request_id
,
...
@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]:
Pool
ingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
# If already dead, error out.
...
...
vllm/engine/protocol.py
View file @
d2f058e7
...
@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
...
@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
...
@@ -209,7 +208,7 @@ class EngineClient(ABC):
...
@@ -209,7 +208,7 @@ class EngineClient(ABC):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
"""Generate outputs for a request from an embedding model."""
...
...
...
...
vllm/entrypoints/llm.py
View file @
d2f058e7
...
@@ -26,7 +26,7 @@ from vllm.logger import init_logger
...
@@ -26,7 +26,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
,
LLMGuidedOptions
)
GuidedDecodingRequest
,
LLMGuidedOptions
)
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
...
@@ -679,7 +679,7 @@ class LLM:
...
@@ -679,7 +679,7 @@ class LLM:
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
overload
# LEGACY: multi (prompt + optional token ids)
@
overload
# LEGACY: multi (prompt + optional token ids)
...
@@ -691,7 +691,7 @@ class LLM:
...
@@ -691,7 +691,7 @@ class LLM:
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
overload
# LEGACY: single (token ids + optional prompt)
@
overload
# LEGACY: single (token ids + optional prompt)
...
@@ -704,7 +704,7 @@ class LLM:
...
@@ -704,7 +704,7 @@ class LLM:
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
overload
# LEGACY: multi (token ids + optional prompt)
@
overload
# LEGACY: multi (token ids + optional prompt)
...
@@ -717,7 +717,7 @@ class LLM:
...
@@ -717,7 +717,7 @@ class LLM:
prompt_token_ids
:
List
[
List
[
int
]],
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
overload
# LEGACY: single or multi token ids [pos-only]
@
overload
# LEGACY: single or multi token ids [pos-only]
...
@@ -728,7 +728,7 @@ class LLM:
...
@@ -728,7 +728,7 @@ class LLM:
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
overload
@
overload
...
@@ -741,7 +741,7 @@ class LLM:
...
@@ -741,7 +741,7 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
...
@@ -759,7 +759,7 @@ class LLM:
...
@@ -759,7 +759,7 @@ class LLM:
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
This class automatically batches the given prompts, considering
...
@@ -778,7 +778,7 @@ class LLM:
...
@@ -778,7 +778,7 @@ class LLM:
generation, if any.
generation, if any.
Returns:
Returns:
A list of ``
Embedd
ingRequestOutput`` objects containing the
A list of ``
Pool
ingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
generated embeddings in the same order as the input prompts.
Note:
Note:
...
@@ -821,7 +821,7 @@ class LLM:
...
@@ -821,7 +821,7 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
return
self
.
engine_class
.
validate_outputs
(
outputs
,
Embedd
ingRequestOutput
)
Pool
ingRequestOutput
)
def
score
(
def
score
(
self
,
self
,
...
@@ -832,7 +832,7 @@ class LLM:
...
@@ -832,7 +832,7 @@ class LLM:
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
"""Generates similarity scores for all pairs <text,text_pair>.
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
...
@@ -854,7 +854,7 @@ class LLM:
...
@@ -854,7 +854,7 @@ class LLM:
generation, if any.
generation, if any.
Returns:
Returns:
A list of ``
Embedd
ingRequestOutput`` objects containing the
A list of ``
Pool
ingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
generated scores in the same order as the input prompts.
"""
"""
task
=
self
.
llm_engine
.
model_config
.
task
task
=
self
.
llm_engine
.
model_config
.
task
...
@@ -943,7 +943,7 @@ class LLM:
...
@@ -943,7 +943,7 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
return
self
.
engine_class
.
validate_outputs
(
outputs
,
Embedd
ingRequestOutput
)
Pool
ingRequestOutput
)
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
self
.
llm_engine
.
start_profile
()
...
@@ -1085,7 +1085,7 @@ class LLM:
...
@@ -1085,7 +1085,7 @@ class LLM:
def
_run_engine
(
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
# Initialize tqdm.
# Initialize tqdm.
if
use_tqdm
:
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
@@ -1098,7 +1098,7 @@ class LLM:
...
@@ -1098,7 +1098,7 @@ class LLM:
)
)
# Run the engine.
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]
=
[]
total_in_toks
=
0
total_in_toks
=
0
total_out_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
d2f058e7
...
@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
...
@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse
,
UsageInfo
)
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
Embedd
ingOutput
,
Embedd
ingRequestOutput
from
vllm.outputs
import
Pool
ingOutput
,
Pool
ingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_get_embedding
(
def
_get_embedding
(
output
:
Embedd
ingOutput
,
output
:
Pool
ingOutput
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
encoding_format
:
Literal
[
"float"
,
"base64"
],
)
->
Union
[
List
[
float
],
str
]:
)
->
Union
[
List
[
float
],
str
]:
if
encoding_format
==
"float"
:
if
encoding_format
==
"float"
:
...
@@ -40,7 +40,7 @@ def _get_embedding(
...
@@ -40,7 +40,7 @@ def _get_embedding(
def
request_output_to_embedding_response
(
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
Embedd
ingRequestOutput
],
request_id
:
str
,
final_res_batch
:
List
[
Pool
ingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
])
->
EmbeddingResponse
:
encoding_format
:
Literal
[
"float"
,
"base64"
])
->
EmbeddingResponse
:
data
:
List
[
EmbeddingResponseData
]
=
[]
data
:
List
[
EmbeddingResponseData
]
=
[]
...
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]]
=
[]
try
:
try
:
pooling_params
=
request
.
to_pooling_params
()
pooling_params
=
request
.
to_pooling_params
()
...
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
num_prompts
=
len
(
engine_prompts
)
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
# Non-streaming response
final_res_batch
:
List
[
Optional
[
Embedd
ingRequestOutput
]]
final_res_batch
:
List
[
Optional
[
Pool
ingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
final_res_batch
=
[
None
]
*
num_prompts
try
:
try
:
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
...
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
Embedd
ingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
Pool
ingRequestOutput
],
final_res_batch
)
final_res_batch
)
response
=
request_output_to_embedding_response
(
response
=
request_output_to_embedding_response
(
...
...
vllm/entrypoints/openai/serving_score.py
View file @
d2f058e7
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
Embedd
ingRequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
make_async
,
merge_async_iterators
,
random_uuid
from
vllm.utils
import
make_async
,
merge_async_iterators
,
random_uuid
...
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
...
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
def
request_output_to_score_response
(
def
request_output_to_score_response
(
final_res_batch
:
List
[
Embedd
ingRequestOutput
],
request_id
:
str
,
final_res_batch
:
List
[
Pool
ingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
)
->
ScoreResponse
:
created_time
:
int
,
model_name
:
str
)
->
ScoreResponse
:
data
:
List
[
ScoreResponseData
]
=
[]
data
:
List
[
ScoreResponseData
]
=
[]
score
=
None
score
=
None
...
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]]
=
[]
input_pairs
=
make_pairs
(
request
.
text_1
,
request
.
text_2
)
input_pairs
=
make_pairs
(
request
.
text_1
,
request
.
text_2
)
...
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
num_prompts
=
len
(
engine_prompts
)
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
# Non-streaming response
final_res_batch
:
List
[
Optional
[
Embedd
ingRequestOutput
]]
final_res_batch
:
List
[
Optional
[
Pool
ingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
final_res_batch
=
[
None
]
*
num_prompts
try
:
try
:
...
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
Embedd
ingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
Pool
ingRequestOutput
],
final_res_batch
)
final_res_batch
)
response
=
request_output_to_score_response
(
response
=
request_output_to_score_response
(
...
...
vllm/model_executor/models/__init__.py
View file @
d2f058e7
from
.interfaces
import
(
HasInnerState
,
SupportsLoRA
,
SupportsMultiModal
,
from
.interfaces
import
(
HasInnerState
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
has_inner_state
,
supports_lora
,
SupportsPP
,
has_inner_state
,
supports_lora
,
supports_multimodal
,
supports_pp
)
supports_multimodal
,
supports_pp
)
from
.interfaces_base
import
(
VllmModelForEmbedding
,
from
.interfaces_base
import
(
VllmModelForPooling
,
VllmModelForTextGeneration
,
VllmModelForTextGeneration
,
is_embedding_model
,
is_pooling_model
,
is_text_generation_model
)
is_text_generation_model
)
from
.registry
import
ModelRegistry
from
.registry
import
ModelRegistry
__all__
=
[
__all__
=
[
"ModelRegistry"
,
"ModelRegistry"
,
"VllmModelFor
Embedd
ing"
,
"VllmModelFor
Pool
ing"
,
"is_
embedd
ing_model"
,
"is_
pool
ing_model"
,
"VllmModelForTextGeneration"
,
"VllmModelForTextGeneration"
,
"is_text_generation_model"
,
"is_text_generation_model"
,
"HasInnerState"
,
"HasInnerState"
,
...
...
vllm/model_executor/models/adapters.py
View file @
d2f058e7
...
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
...
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.interfaces_base
import
VllmModelFor
Embedd
ing
,
is_
embedd
ing_model
from
.interfaces_base
import
VllmModelFor
Pool
ing
,
is_
pool
ing_model
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
...
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
...
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
def
as_embedding_model
(
cls
:
_T
)
->
_T
:
def
as_embedding_model
(
cls
:
_T
)
->
_T
:
"""Subclass an existing vLLM model to support embeddings."""
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
# Avoid modifying existing embedding models
if
is_
embedd
ing_model
(
cls
):
if
is_
pool
ing_model
(
cls
):
return
cls
return
cls
# Lazy import
# Lazy import
...
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
...
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForEmbedding
(
cls
,
VllmModelFor
Embedd
ing
):
class
ModelForEmbedding
(
cls
,
VllmModelFor
Pool
ing
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/models/interfaces.py
View file @
d2f058e7
...
@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
...
@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
from
vllm.utils
import
supports_kw
from
.interfaces_base
import
is_
embedd
ing_model
from
.interfaces_base
import
is_
pool
ing_model
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
...
@@ -389,4 +389,4 @@ def _supports_cross_encoding(
...
@@ -389,4 +389,4 @@ def _supports_cross_encoding(
def
supports_cross_encoding
(
def
supports_cross_encoding
(
model
:
Union
[
Type
[
object
],
object
],
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
)
->
Union
[
TypeIs
[
Type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
return
is_
embedd
ing_model
(
model
)
and
_supports_cross_encoding
(
model
)
return
is_
pool
ing_model
(
model
)
and
_supports_cross_encoding
(
model
)
vllm/model_executor/models/interfaces_base.py
View file @
d2f058e7
...
@@ -141,7 +141,7 @@ def is_text_generation_model(
...
@@ -141,7 +141,7 @@ def is_text_generation_model(
@
runtime_checkable
@
runtime_checkable
class
VllmModelFor
Embedd
ing
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
class
VllmModelFor
Pool
ing
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
def
pooler
(
def
pooler
(
self
,
self
,
...
@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
...
@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
@
overload
@
overload
def
is_embedding_model
(
def
is_pooling_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForPooling
]]:
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForEmbedding
]]:
...
...
@
overload
@
overload
def
is_
embedd
ing_model
(
model
:
object
)
->
TypeIs
[
VllmModelFor
Embedd
ing
]:
def
is_
pool
ing_model
(
model
:
object
)
->
TypeIs
[
VllmModelFor
Pool
ing
]:
...
...
def
is_
embedd
ing_model
(
def
is_
pool
ing_model
(
model
:
Union
[
Type
[
object
],
object
],
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
VllmModelFor
Embedd
ing
]],
TypeIs
[
VllmModelFor
Embedd
ing
]]:
)
->
Union
[
TypeIs
[
Type
[
VllmModelFor
Pool
ing
]],
TypeIs
[
VllmModelFor
Pool
ing
]]:
if
not
is_vllm_model
(
model
):
if
not
is_vllm_model
(
model
):
return
False
return
False
if
isinstance
(
model
,
type
):
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
VllmModelFor
Embedd
ing
)
return
isinstance
(
model
,
VllmModelFor
Pool
ing
)
return
isinstance
(
model
,
VllmModelFor
Embedd
ing
)
return
isinstance
(
model
,
VllmModelFor
Pool
ing
)
vllm/model_executor/models/registry.py
View file @
d2f058e7
...
@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
...
@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
supports_cross_encoding
,
supports_multimodal
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
supports_pp
)
from
.interfaces_base
import
is_
embedd
ing_model
,
is_text_generation_model
from
.interfaces_base
import
is_
pool
ing_model
,
is_text_generation_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
...
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class
_ModelInfo
:
class
_ModelInfo
:
architecture
:
str
architecture
:
str
is_text_generation_model
:
bool
is_text_generation_model
:
bool
is_
embedd
ing_model
:
bool
is_
pool
ing_model
:
bool
supports_cross_encoding
:
bool
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal
:
bool
supports_pp
:
bool
supports_pp
:
bool
...
@@ -220,19 +220,19 @@ class _ModelInfo:
...
@@ -220,19 +220,19 @@ class _ModelInfo:
@
staticmethod
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
is_
embedd
ing_model_
=
is_
embedd
ing_model
(
model
)
is_
pool
ing_model_
=
is_
pool
ing_model
(
model
)
if
not
is_
embedd
ing_model_
:
if
not
is_
pool
ing_model_
:
try
:
try
:
as_embedding_model
(
model
)
as_embedding_model
(
model
)
except
Exception
:
except
Exception
:
pass
pass
else
:
else
:
is_
embedd
ing_model_
=
True
is_
pool
ing_model_
=
True
return
_ModelInfo
(
return
_ModelInfo
(
architecture
=
model
.
__name__
,
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_text_generation_model
=
is_text_generation_model
(
model
),
is_
embedd
ing_model
=
is_
embedd
ing_model_
,
is_
pool
ing_model
=
is_
pool
ing_model_
,
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_pp
=
supports_pp
(
model
),
supports_pp
=
supports_pp
(
model
),
...
@@ -441,12 +441,12 @@ class _ModelRegistry:
...
@@ -441,12 +441,12 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_text_generation_model
return
model_cls
.
is_text_generation_model
def
is_
embedd
ing_model
(
def
is_
pool
ing_model
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_
embedd
ing_model
return
model_cls
.
is_
pool
ing_model
def
is_cross_encoder_model
(
def
is_cross_encoder_model
(
self
,
self
,
...
...
vllm/outputs.py
View file @
d2f058e7
...
@@ -53,8 +53,8 @@ class CompletionOutput:
...
@@ -53,8 +53,8 @@ class CompletionOutput:
@
dataclass
@
dataclass
class
Embedd
ingOutput
:
class
Pool
ingOutput
:
"""The output data of one
completion
output of a request.
"""The output data of one
pooling
output of a request.
Args:
Args:
embedding: The embedding vector, which is a list of floats. The
embedding: The embedding vector, which is a list of floats. The
...
@@ -63,7 +63,7 @@ class EmbeddingOutput:
...
@@ -63,7 +63,7 @@ class EmbeddingOutput:
embedding
:
List
[
float
]
embedding
:
List
[
float
]
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"
Embedd
ingOutput("
return
(
f
"
Pool
ingOutput("
f
"embedding=
{
len
(
self
.
embedding
)
}
)"
)
f
"embedding=
{
len
(
self
.
embedding
)
}
)"
)
...
@@ -316,18 +316,18 @@ class RequestOutput:
...
@@ -316,18 +316,18 @@ class RequestOutput:
f
"multi_modal_placeholders=
{
self
.
multi_modal_placeholders
}
)"
)
f
"multi_modal_placeholders=
{
self
.
multi_modal_placeholders
}
)"
)
class
Embedd
ingRequestOutput
:
class
Pool
ingRequestOutput
:
"""
"""
The output data of a
n embedd
ing request to the LLM.
The output data of a
pool
ing request to the LLM.
Args:
Args:
request_id (str): A unique identifier for the
embedd
ing request.
request_id (str): A unique identifier for the
pool
ing request.
outputs (
Embedd
ingOutput): The
embedd
ing results for the given input.
outputs (
Pool
ingOutput): The
pool
ing results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the
embedd
ing is completed.
finished (bool): A flag indicating whether the
pool
ing is completed.
"""
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
"
Embedd
ingOutput"
,
def
__init__
(
self
,
request_id
:
str
,
outputs
:
"
Pool
ingOutput"
,
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
...
@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
...
@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
@
classmethod
@
classmethod
def
from_seq_group
(
cls
,
def
from_seq_group
(
cls
,
seq_group
:
'SequenceGroup'
)
->
"
Embedd
ingRequestOutput"
:
seq_group
:
'SequenceGroup'
)
->
"
Pool
ingRequestOutput"
:
if
seq_group
.
embeddings
is
None
:
if
seq_group
.
embeddings
is
None
:
raise
ValueError
(
raise
ValueError
(
"Embeddings are missing in seq_group for EmbeddingRequest."
)
"Embeddings are missing in seq_group for EmbeddingRequest."
)
output
=
Embedd
ingOutput
(
seq_group
.
embeddings
)
output
=
Pool
ingOutput
(
seq_group
.
embeddings
)
prompt_token_ids
=
seq_group
.
prompt_token_ids
prompt_token_ids
=
seq_group
.
prompt_token_ids
finished
=
seq_group
.
is_finished
()
finished
=
seq_group
.
is_finished
()
...
@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
...
@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
def
__repr__
(
self
):
def
__repr__
(
self
):
"""
"""
Returns a string representation of an
Embedd
ingRequestOutput instance.
Returns a string representation of an
Pool
ingRequestOutput instance.
The representation includes the request_id and the number of outputs,
The representation includes the request_id and the number of outputs,
providing a quick overview of the
embedd
ing request's results.
providing a quick overview of the
pool
ing request's results.
Returns:
Returns:
str: A string representation of the
Embedd
ingRequestOutput instance.
str: A string representation of the
Pool
ingRequestOutput instance.
"""
"""
return
(
f
"
Embedd
ingRequestOutput(request_id='
{
self
.
request_id
}
', "
return
(
f
"
Pool
ingRequestOutput(request_id='
{
self
.
request_id
}
', "
f
"outputs=
{
repr
(
self
.
outputs
)
}
, "
f
"outputs=
{
repr
(
self
.
outputs
)
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"finished=
{
self
.
finished
}
)"
)
f
"finished=
{
self
.
finished
}
)"
)
...
@@ -415,7 +415,30 @@ class RequestOutputFactory:
...
@@ -415,7 +415,30 @@ class RequestOutputFactory:
# Determine the type based on a condition, for example:
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
Embedd
ingRequestOutput
.
from_seq_group
(
seq_group
)
return
Pool
ingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
,
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
,
seq_id_to_seq_group
)
seq_id_to_seq_group
)
def
__getattr__
(
name
:
str
):
import
warnings
if
name
==
"EmbeddingOutput"
:
msg
=
(
"EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingOutput
if
name
==
"EmbeddingRequestOutput"
:
msg
=
(
"EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingRequestOutput
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/v1/engine/async_llm.py
View file @
d2f058e7
...
@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
...
@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
...
@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
"""Add new request to the AsyncLLM."""
"""Add new request to the AsyncLLM."""
if
self
.
detokenizer
.
is_request_active
(
request_id
):
if
self
.
detokenizer
.
is_request_active
(
request_id
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment