Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2f058e7
Unverified
Commit
d2f058e7
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f877a7d1
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
153 additions
and
110 deletions
+153
-110
examples/offline_inference_embedding.py
examples/offline_inference_embedding.py
+1
-1
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+3
-3
tests/models/test_registry.py
tests/models/test_registry.py
+2
-2
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+2
-2
vllm/__init__.py
vllm/__init__.py
+27
-4
vllm/config.py
vllm/config.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-12
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-4
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+7
-7
vllm/engine/protocol.py
vllm/engine/protocol.py
+2
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+15
-15
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+6
-6
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+5
-5
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+5
-6
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+3
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+2
-2
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+7
-8
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+8
-8
vllm/outputs.py
vllm/outputs.py
+39
-16
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+2
-2
No files found.
examples/offline_inference_embedding.py
View file @
d2f058e7
...
...
@@ -10,7 +10,7 @@ prompts = [
# Create an LLM.
model
=
LLM
(
model
=
"intfloat/e5-mistral-7b-instruct"
,
enforce_eager
=
True
)
# Generate embedding. The output is a list of
Embedd
ingRequestOutputs.
# Generate embedding. The output is a list of
Pool
ingRequestOutputs.
outputs
=
model
.
encode
(
prompts
)
# Print the outputs.
for
output
in
outputs
:
...
...
tests/entrypoints/llm/test_encode.py
View file @
d2f058e7
...
...
@@ -3,7 +3,7 @@ from typing import List
import
pytest
from
vllm
import
LLM
,
Embedd
ingRequestOutput
,
PoolingParams
from
vllm
import
LLM
,
PoolingParams
,
Pool
ingRequestOutput
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
...
...
@@ -43,8 +43,8 @@ def llm():
cleanup_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
Embedd
ingRequestOutput
],
o2
:
List
[
Embedd
ingRequestOutput
]):
def
assert_outputs_equal
(
o1
:
List
[
Pool
ingRequestOutput
],
o2
:
List
[
Pool
ingRequestOutput
]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
...
...
tests/models/test_registry.py
View file @
d2f058e7
...
...
@@ -3,7 +3,7 @@ import warnings
import
pytest
import
torch.cuda
from
vllm.model_executor.models
import
(
is_
embedd
ing_model
,
from
vllm.model_executor.models
import
(
is_
pool
ing_model
,
is_text_generation_model
,
supports_multimodal
)
from
vllm.model_executor.models.adapters
import
as_embedding_model
...
...
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
# All vLLM models should be convertible to an embedding model
embed_model
=
as_embedding_model
(
model_cls
)
assert
is_
embedd
ing_model
(
embed_model
)
assert
is_
pool
ing_model
(
embed_model
)
if
model_arch
in
_MULTIMODAL_MODELS
:
assert
supports_multimodal
(
model_cls
)
...
...
tests/worker/test_model_input.py
View file @
d2f058e7
...
...
@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.worker.embedding_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.multi_step_model_runner
import
StatefulModelInput
from
vllm.worker.pooling_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
class
MockAttentionBackend
(
AttentionBackend
):
...
...
vllm/__init__.py
View file @
d2f058e7
...
...
@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptType
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
Embedd
ingOutput
,
Embedd
ingRequestOutput
,
RequestOutput
)
from
vllm.outputs
import
(
CompletionOutput
,
Pool
ingOutput
,
Pool
ingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -25,8 +25,8 @@ __all__ = [
"SamplingParams"
,
"RequestOutput"
,
"CompletionOutput"
,
"
Embedd
ingOutput"
,
"
Embedd
ingRequestOutput"
,
"
Pool
ingOutput"
,
"
Pool
ingRequestOutput"
,
"LLMEngine"
,
"EngineArgs"
,
"AsyncLLMEngine"
,
...
...
@@ -34,3 +34,26 @@ __all__ = [
"initialize_ray_cluster"
,
"PoolingParams"
,
]
def
__getattr__
(
name
:
str
):
import
warnings
if
name
==
"EmbeddingOutput"
:
msg
=
(
"EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingOutput
if
name
==
"EmbeddingRequestOutput"
:
msg
=
(
"EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PoolingRequestOutput
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/config.py
View file @
d2f058e7
...
...
@@ -359,7 +359,7 @@ class ModelConfig:
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"embedding"
:
ModelRegistry
.
is_
embedd
ing_model
(
architectures
),
"embedding"
:
ModelRegistry
.
is_
pool
ing_model
(
architectures
),
}
supported_tasks_lst
:
List
[
_Task
]
=
[
task
for
task
,
is_supported
in
task_support
.
items
()
if
is_supported
...
...
vllm/engine/async_llm_engine.py
View file @
d2f058e7
...
...
@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
class
AsyncStream
:
"""A stream of RequestOutputs or
Embedd
ingRequestOutputs for a request
"""A stream of RequestOutputs or
Pool
ingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
])
->
None
:
...
...
@@ -83,7 +83,7 @@ class AsyncStream:
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Embedd
ingRequestOutput
,
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Pool
ingRequestOutput
,
Exception
])
->
None
:
if
not
self
.
_finished
:
self
.
_queue
.
put_nowait
(
item
)
...
...
@@ -103,7 +103,7 @@ class AsyncStream:
async
def
generator
(
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
try
:
while
True
:
result
=
await
self
.
_queue
.
get
()
...
...
@@ -154,7 +154,7 @@ class RequestTracker:
def
process_request_output
(
self
,
request_output
:
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
Pool
ingRequestOutput
],
*
,
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
...
...
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
async
def
step_async
(
self
,
virtual_engine
:
int
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
...
...
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]]:
RequestOutput
,
Pool
ingRequestOutput
],
None
]]:
...
@
overload
...
...
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]]:
RequestOutput
,
Pool
ingRequestOutput
],
None
]]:
...
@
deprecate_kwargs
(
...
...
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
...
...
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
...
...
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
Only applicable with priority scheduling.
Yields:
The output `
Embedd
ingRequestOutput` objects from the LLMEngine
The output `
Pool
ingRequestOutput` objects from the LLMEngine
for the request.
Details:
...
...
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
=
trace_headers
,
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
Embedd
ingRequestOutput
)
yield
LLMEngine
.
validate_output
(
output
,
Pool
ingRequestOutput
)
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
...
...
vllm/engine/llm_engine.py
View file @
d2f058e7
...
...
@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
(
Embedd
ingRequestOutput
,
RequestOutput
,
from
vllm.outputs
import
(
Pool
ingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
Embedd
ingRequestOutput
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
Pool
ingRequestOutput
)
@
dataclass
...
...
@@ -112,7 +112,7 @@ class SchedulerContext:
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]
=
[]
Pool
ingRequestOutput
]]
=
[]
self
.
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
...
...
@@ -1314,7 +1314,7 @@ class LLMEngine:
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
...
...
vllm/engine/multiprocessing/client.py
View file @
d2f058e7
...
...
@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
...
@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
...
@
overload
...
...
@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
...
...
@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
...
...
@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: OpenTelemetry trace headers.
Yields:
The output `
Embedd
ingRequestOutput` objects from the LLMEngine
The output `
Pool
ingRequestOutput` objects from the LLMEngine
for the request.
"""
if
inputs
is
not
None
:
...
...
@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
and
request_id
is
not
None
)
return
cast
(
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
],
AsyncGenerator
[
Pool
ingRequestOutput
,
None
],
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
...
...
@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]:
Pool
ingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
...
...
vllm/engine/protocol.py
View file @
d2f058e7
...
...
@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
...
...
@@ -209,7 +208,7 @@ class EngineClient(ABC):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
...
...
...
vllm/entrypoints/llm.py
View file @
d2f058e7
...
...
@@ -26,7 +26,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
,
LLMGuidedOptions
)
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
...
...
@@ -679,7 +679,7 @@ class LLM:
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
overload
# LEGACY: multi (prompt + optional token ids)
...
...
@@ -691,7 +691,7 @@ class LLM:
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
overload
# LEGACY: single (token ids + optional prompt)
...
...
@@ -704,7 +704,7 @@ class LLM:
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
overload
# LEGACY: multi (token ids + optional prompt)
...
...
@@ -717,7 +717,7 @@ class LLM:
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
overload
# LEGACY: single or multi token ids [pos-only]
...
...
@@ -728,7 +728,7 @@ class LLM:
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
overload
...
...
@@ -741,7 +741,7 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
...
@
deprecate_kwargs
(
...
...
@@ -759,7 +759,7 @@ class LLM:
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
...
...
@@ -778,7 +778,7 @@ class LLM:
generation, if any.
Returns:
A list of ``
Embedd
ingRequestOutput`` objects containing the
A list of ``
Pool
ingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:
...
...
@@ -821,7 +821,7 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
Embedd
ingRequestOutput
)
Pool
ingRequestOutput
)
def
score
(
self
,
...
...
@@ -832,7 +832,7 @@ class LLM:
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
Embedd
ingRequestOutput
]:
)
->
List
[
Pool
ingRequestOutput
]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
...
...
@@ -854,7 +854,7 @@ class LLM:
generation, if any.
Returns:
A list of ``
Embedd
ingRequestOutput`` objects containing the
A list of ``
Pool
ingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task
=
self
.
llm_engine
.
model_config
.
task
...
...
@@ -943,7 +943,7 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
Embedd
ingRequestOutput
)
Pool
ingRequestOutput
)
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
...
...
@@ -1085,7 +1085,7 @@ class LLM:
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]:
# Initialize tqdm.
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
...
@@ -1098,7 +1098,7 @@ class LLM:
)
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
]]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
]]
=
[]
total_in_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
d2f058e7
...
...
@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
Embedd
ingOutput
,
Embedd
ingRequestOutput
from
vllm.outputs
import
Pool
ingOutput
,
Pool
ingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
def
_get_embedding
(
output
:
Embedd
ingOutput
,
output
:
Pool
ingOutput
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
)
->
Union
[
List
[
float
],
str
]:
if
encoding_format
==
"float"
:
...
...
@@ -40,7 +40,7 @@ def _get_embedding(
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
Embedd
ingRequestOutput
],
request_id
:
str
,
final_res_batch
:
List
[
Pool
ingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
])
->
EmbeddingResponse
:
data
:
List
[
EmbeddingResponseData
]
=
[]
...
...
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]]
=
[]
try
:
pooling_params
=
request
.
to_pooling_params
()
...
...
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
Embedd
ingRequestOutput
]]
final_res_batch
:
List
[
Optional
[
Pool
ingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
try
:
async
for
i
,
res
in
result_generator
:
...
...
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
Embedd
ingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
Pool
ingRequestOutput
],
final_res_batch
)
response
=
request_output_to_embedding_response
(
...
...
vllm/entrypoints/openai/serving_score.py
View file @
d2f058e7
...
...
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
Embedd
ingRequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
make_async
,
merge_async_iterators
,
random_uuid
...
...
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
def
request_output_to_score_response
(
final_res_batch
:
List
[
Embedd
ingRequestOutput
],
request_id
:
str
,
final_res_batch
:
List
[
Pool
ingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
)
->
ScoreResponse
:
data
:
List
[
ScoreResponseData
]
=
[]
score
=
None
...
...
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
Embedd
ingRequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
Pool
ingRequestOutput
,
None
]]
=
[]
input_pairs
=
make_pairs
(
request
.
text_1
,
request
.
text_2
)
...
...
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
Embedd
ingRequestOutput
]]
final_res_batch
:
List
[
Optional
[
Pool
ingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
try
:
...
...
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
Embedd
ingRequestOutput
],
final_res_batch_checked
=
cast
(
List
[
Pool
ingRequestOutput
],
final_res_batch
)
response
=
request_output_to_score_response
(
...
...
vllm/model_executor/models/__init__.py
View file @
d2f058e7
from
.interfaces
import
(
HasInnerState
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
has_inner_state
,
supports_lora
,
supports_multimodal
,
supports_pp
)
from
.interfaces_base
import
(
VllmModelForEmbedding
,
VllmModelForTextGeneration
,
is_embedding_model
,
is_text_generation_model
)
from
.interfaces_base
import
(
VllmModelForPooling
,
VllmModelForTextGeneration
,
is_pooling_model
,
is_text_generation_model
)
from
.registry
import
ModelRegistry
__all__
=
[
"ModelRegistry"
,
"VllmModelFor
Embedd
ing"
,
"is_
embedd
ing_model"
,
"VllmModelFor
Pool
ing"
,
"is_
pool
ing_model"
,
"VllmModelForTextGeneration"
,
"is_text_generation_model"
,
"HasInnerState"
,
...
...
vllm/model_executor/models/adapters.py
View file @
d2f058e7
...
...
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
import
torch
import
torch.nn
as
nn
from
.interfaces_base
import
VllmModelFor
Embedd
ing
,
is_
embedd
ing_model
from
.interfaces_base
import
VllmModelFor
Pool
ing
,
is_
pool
ing_model
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
...
...
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
def
as_embedding_model
(
cls
:
_T
)
->
_T
:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if
is_
embedd
ing_model
(
cls
):
if
is_
pool
ing_model
(
cls
):
return
cls
# Lazy import
...
...
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForEmbedding
(
cls
,
VllmModelFor
Embedd
ing
):
class
ModelForEmbedding
(
cls
,
VllmModelFor
Pool
ing
):
def
__init__
(
self
,
...
...
vllm/model_executor/models/interfaces.py
View file @
d2f058e7
...
...
@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
from
.interfaces_base
import
is_
embedd
ing_model
from
.interfaces_base
import
is_
pool
ing_model
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
...
...
@@ -389,4 +389,4 @@ def _supports_cross_encoding(
def
supports_cross_encoding
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
return
is_
embedd
ing_model
(
model
)
and
_supports_cross_encoding
(
model
)
return
is_
pool
ing_model
(
model
)
and
_supports_cross_encoding
(
model
)
vllm/model_executor/models/interfaces_base.py
View file @
d2f058e7
...
...
@@ -141,7 +141,7 @@ def is_text_generation_model(
@
runtime_checkable
class
VllmModelFor
Embedd
ing
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
class
VllmModelFor
Pool
ing
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
def
pooler
(
self
,
...
...
@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
@
overload
def
is_embedding_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForEmbedding
]]:
def
is_pooling_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForPooling
]]:
...
@
overload
def
is_
embedd
ing_model
(
model
:
object
)
->
TypeIs
[
VllmModelFor
Embedd
ing
]:
def
is_
pool
ing_model
(
model
:
object
)
->
TypeIs
[
VllmModelFor
Pool
ing
]:
...
def
is_
embedd
ing_model
(
def
is_
pool
ing_model
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
VllmModelFor
Embedd
ing
]],
TypeIs
[
VllmModelFor
Embedd
ing
]]:
)
->
Union
[
TypeIs
[
Type
[
VllmModelFor
Pool
ing
]],
TypeIs
[
VllmModelFor
Pool
ing
]]:
if
not
is_vllm_model
(
model
):
return
False
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
VllmModelFor
Embedd
ing
)
return
isinstance
(
model
,
VllmModelFor
Pool
ing
)
return
isinstance
(
model
,
VllmModelFor
Embedd
ing
)
return
isinstance
(
model
,
VllmModelFor
Pool
ing
)
vllm/model_executor/models/registry.py
View file @
d2f058e7
...
...
@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
from
.interfaces_base
import
is_
embedd
ing_model
,
is_text_generation_model
from
.interfaces_base
import
is_
pool
ing_model
,
is_text_generation_model
logger
=
init_logger
(
__name__
)
...
...
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class
_ModelInfo
:
architecture
:
str
is_text_generation_model
:
bool
is_
embedd
ing_model
:
bool
is_
pool
ing_model
:
bool
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_pp
:
bool
...
...
@@ -220,19 +220,19 @@ class _ModelInfo:
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
is_
embedd
ing_model_
=
is_
embedd
ing_model
(
model
)
if
not
is_
embedd
ing_model_
:
is_
pool
ing_model_
=
is_
pool
ing_model
(
model
)
if
not
is_
pool
ing_model_
:
try
:
as_embedding_model
(
model
)
except
Exception
:
pass
else
:
is_
embedd
ing_model_
=
True
is_
pool
ing_model_
=
True
return
_ModelInfo
(
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_
embedd
ing_model
=
is_
embedd
ing_model_
,
is_
pool
ing_model
=
is_
pool
ing_model_
,
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_pp
=
supports_pp
(
model
),
...
...
@@ -441,12 +441,12 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_text_generation_model
def
is_
embedd
ing_model
(
def
is_
pool
ing_model
(
self
,
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_
embedd
ing_model
return
model_cls
.
is_
pool
ing_model
def
is_cross_encoder_model
(
self
,
...
...
vllm/outputs.py
View file @
d2f058e7
...
...
@@ -53,8 +53,8 @@ class CompletionOutput:
@
dataclass
class
Embedd
ingOutput
:
"""The output data of one
completion
output of a request.
class
Pool
ingOutput
:
"""The output data of one
pooling
output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
...
...
@@ -63,7 +63,7 @@ class EmbeddingOutput:
embedding
:
List
[
float
]
def
__repr__
(
self
)
->
str
:
return
(
f
"
Embedd
ingOutput("
return
(
f
"
Pool
ingOutput("
f
"embedding=
{
len
(
self
.
embedding
)
}
)"
)
...
...
@@ -316,18 +316,18 @@ class RequestOutput:
f
"multi_modal_placeholders=
{
self
.
multi_modal_placeholders
}
)"
)
class
Embedd
ingRequestOutput
:
class
Pool
ingRequestOutput
:
"""
The output data of a
n embedd
ing request to the LLM.
The output data of a
pool
ing request to the LLM.
Args:
request_id (str): A unique identifier for the
embedd
ing request.
outputs (
Embedd
ingOutput): The
embedd
ing results for the given input.
request_id (str): A unique identifier for the
pool
ing request.
outputs (
Pool
ingOutput): The
pool
ing results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the
embedd
ing is completed.
finished (bool): A flag indicating whether the
pool
ing is completed.
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
"
Embedd
ingOutput"
,
def
__init__
(
self
,
request_id
:
str
,
outputs
:
"
Pool
ingOutput"
,
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
...
...
@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
'SequenceGroup'
)
->
"
Embedd
ingRequestOutput"
:
seq_group
:
'SequenceGroup'
)
->
"
Pool
ingRequestOutput"
:
if
seq_group
.
embeddings
is
None
:
raise
ValueError
(
"Embeddings are missing in seq_group for EmbeddingRequest."
)
output
=
Embedd
ingOutput
(
seq_group
.
embeddings
)
output
=
Pool
ingOutput
(
seq_group
.
embeddings
)
prompt_token_ids
=
seq_group
.
prompt_token_ids
finished
=
seq_group
.
is_finished
()
...
...
@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
def
__repr__
(
self
):
"""
Returns a string representation of an
Embedd
ingRequestOutput instance.
Returns a string representation of a
Pool
ingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the
embedd
ing request's results.
providing a quick overview of the
pool
ing request's results.
Returns:
str: A string representation of the
Embedd
ingRequestOutput instance.
str: A string representation of the
Pool
ingRequestOutput instance.
"""
return
(
f
"
Embedd
ingRequestOutput(request_id='
{
self
.
request_id
}
', "
return
(
f
"
Pool
ingRequestOutput(request_id='
{
self
.
request_id
}
', "
f
"outputs=
{
repr
(
self
.
outputs
)
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"finished=
{
self
.
finished
}
)"
)
...
...
@@ -415,7 +415,30 @@ class RequestOutputFactory:
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
Embedd
ingRequestOutput
.
from_seq_group
(
seq_group
)
return
Pool
ingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
,
seq_id_to_seq_group
)
def __getattr__(name: str):
    """Resolve the deprecated ``Embedding*`` names to their ``Pooling*`` classes.

    Python calls this module-level hook when ``name`` is not found among the
    module's regular attributes. The two renamed classes are still served
    under their old names, but each access emits a ``DeprecationWarning``.

    Args:
        name: The attribute name that was looked up on this module.

    Returns:
        ``PoolingOutput`` for ``"EmbeddingOutput"`` and
        ``PoolingRequestOutput`` for ``"EmbeddingRequestOutput"``.

    Raises:
        AttributeError: If ``name`` is not one of the two deprecated aliases.
    """
    import warnings

    # Old public name -> name of the class that replaced it.
    renamed = {
        "EmbeddingOutput": "PoolingOutput",
        "EmbeddingRequestOutput": "PoolingRequestOutput",
    }

    if name not in renamed:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")

    notice = (f"{name} has been renamed to {renamed[name]}. "
              "The original name will be removed in an upcoming version.")
    # stacklevel=2 attributes the warning to the code that accessed the
    # old name rather than to this hook.
    warnings.warn(DeprecationWarning(notice), stacklevel=2)

    # The replacement class is looked up only after the warning is issued,
    # and only for the branch that is actually taken.
    return PoolingOutput if name == "EmbeddingOutput" else PoolingRequestOutput
vllm/v1/engine/async_llm.py
View file @
d2f058e7
...
...
@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
Embedd
ingRequestOutput
,
RequestOutput
from
vllm.outputs
import
Pool
ingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Embedd
ingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
Pool
ingRequestOutput
],
None
]:
"""Add new request to the AsyncLLM."""
if
self
.
detokenizer
.
is_request_active
(
request_id
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment