Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
28e1299e
Unverified
Commit
28e1299e
authored
Sep 26, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 25, 2024
Browse files
rename PromptInputs and inputs with backward compatibility (#8760)
parent
0c4d2ad5
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
395 additions
and
202 deletions
+395
-202
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+4
-4
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+1
-1
docs/source/dev/offline_inference/llm_inputs.rst
docs/source/dev/offline_inference/llm_inputs.rst
+1
-1
docs/source/models/vlm.rst
docs/source/models/vlm.rst
+1
-1
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+5
-3
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+0
-34
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+0
-37
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+6
-6
tests/mq_llm_engine/utils.py
tests/mq_llm_engine/utils.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+92
-18
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+45
-7
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+58
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+82
-13
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+35
-33
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+17
-3
vllm/inputs/data.py
vllm/inputs/data.py
+29
-19
vllm/inputs/parse.py
vllm/inputs/parse.py
+11
-11
No files found.
benchmarks/benchmark_latency.py
View file @
28e1299e
...
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
size
=
(
args
.
batch_size
,
args
.
input_len
))
dummy_
inpu
ts
:
List
[
Prompt
Inputs
]
=
[{
dummy_
promp
ts
:
List
[
Prompt
Type
]
=
[{
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
...
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
)))
as
p
:
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
print
(
p
.
key_averages
())
else
:
start_time
=
time
.
perf_counter
()
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
end_time
=
time
.
perf_counter
()
...
...
docs/source/dev/multimodal/multimodal_index.rst
View file @
28e1299e
...
...
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Inputs
`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Type
`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
...
...
docs/source/dev/offline_inference/llm_inputs.rst
View file @
28e1299e
LLM Inputs
==========
.. autodata:: vllm.inputs.Prompt
Inputs
.. autodata:: vllm.inputs.Prompt
Type
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
...
...
docs/source/models/vlm.rst
View file @
28e1299e
...
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Inputs
`:
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Type
`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
...
...
tests/async_engine/test_async_llm_engine.py
View file @
28e1299e
...
...
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@
pytest
.
mark
.
asyncio
async
def
test_new_requests_event
():
params
=
SamplingParams
()
engine
=
MockAsyncLLMEngine
()
engine
.
start_background_loop
()
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
step_calls
==
0
await
engine
.
add_request
(
"1"
,
""
,
None
)
await
engine
.
add_request
(
"1"
,
""
,
params
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
1
assert
engine
.
engine
.
step_calls
==
1
await
engine
.
add_request
(
"2"
,
""
,
None
)
await
engine
.
add_request
(
"2"
,
""
,
params
)
engine
.
engine
.
generate
(
"2"
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
...
...
@@ -111,7 +113,7 @@ async def test_new_requests_event():
await
asyncio
.
sleep
(
0.001
)
assert
engine
.
engine
.
step_calls
==
old_step_calls
await
engine
.
add_request
(
"3"
,
""
,
None
)
await
engine
.
add_request
(
"3"
,
""
,
params
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
...
...
tests/entrypoints/llm/test_encode.py
View file @
28e1299e
...
...
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
prompt
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
prompt
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
({
"prompt"
:
prompt
},
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
...
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
PROMPTS
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
pooling_params
=
pooling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
...
...
tests/entrypoints/llm/test_generate.py
View file @
28e1299e
...
...
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
({
"prompt"
:
prompt
},
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
...
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
PROMPTS
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
sampling_params
=
sampling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
...
...
tests/mq_llm_engine/test_error_handling.py
View file @
28e1299e
...
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with
pytest
.
raises
(
RAISED_ERROR
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(
max_tokens
=
10
),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with
pytest
.
raises
(
ValueError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-1"
,
lora_request
=
LoRARequest
(
...
...
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-2"
):
pass
...
...
tests/mq_llm_engine/utils.py
View file @
28e1299e
...
...
@@ -20,7 +20,7 @@ async def generate(
count
=
0
async
for
out
in
client
.
generate
(
request_id
=
request_id
,
inputs
=
"Hello my name is Robert and"
,
prompt
=
"Hello my name is Robert and"
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
temperature
=
0
)):
...
...
vllm/__init__.py
View file @
28e1299e
...
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
...
...
@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__"
,
"LLM"
,
"ModelRegistry"
,
"Prompt
Inputs
"
,
"Prompt
Type
"
,
"TextPrompt"
,
"TokensPrompt"
,
"SamplingParams"
,
...
...
vllm/engine/async_llm_engine.py
View file @
28e1299e
...
...
@@ -2,8 +2,8 @@ import asyncio
import
time
import
weakref
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Coroutine
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
overload
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
...
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
...
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
@
overload
# DEPRECATED
async
def
add_request_async
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Async version of :meth:`add_request`."""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -774,16 +811,55 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# for backwards compatibility.
async
def
add_request
(
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
...
...
@@ -797,7 +873,7 @@ class AsyncLLMEngine:
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
...
...
@@ -808,7 +884,7 @@ class AsyncLLMEngine:
async
def
generate
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -822,8 +898,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -881,7 +956,7 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
inputs
,
prompt
,
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
...
...
@@ -891,7 +966,7 @@ class AsyncLLMEngine:
async
def
encode
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -904,8 +979,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -959,7 +1033,7 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
inputs
,
prompt
,
pooling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
...
...
vllm/engine/llm_engine.py
View file @
28e1299e
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
from
typing
import
Set
,
Type
,
Union
,
overload
import
torch
from
typing_extensions
import
TypeVar
...
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
Prompt
Inputs
)
InputRegistry
,
LLMInputs
,
Prompt
Type
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -689,16 +689,51 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -708,8 +743,7 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
...
...
@@ -744,6 +778,10 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -756,7 +794,7 @@ class LLMEngine:
arrival_time
=
time
.
time
()
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/engine/multiprocessing/__init__.py
View file @
28e1299e
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
,
overload
from
vllm
import
PoolingParams
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
...
@@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
class
RPCProcessRequest
:
inputs
:
Prompt
Inputs
prompt
:
Prompt
Type
params
:
Union
[
SamplingParams
,
PoolingParams
]
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
@
overload
# DEPRECATED
def
__init__
(
self
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
def
__init__
(
self
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
__init__
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
params
is
not
None
and
request_id
is
not
None
)
super
().
__init__
()
self
.
prompt
=
prompt
self
.
params
=
params
self
.
request_id
=
request_id
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
@
dataclass
class
RPCError
:
...
...
vllm/engine/multiprocessing/client.py
View file @
28e1299e
...
...
@@ -3,7 +3,7 @@ import copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
)
Union
,
overload
)
import
cloudpickle
import
zmq
...
...
@@ -24,13 +24,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest
,
RPCStartupResponse
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
logger
=
init_logger
(
__name__
)
...
...
@@ -366,14 +367,45 @@ class MQLLMEngineClient:
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
@
overload
# DEPRECATED
def
generate
(
self
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
overload
def
generate
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
generate
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -382,8 +414,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -392,17 +423,51 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
sampling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
@
overload
# DEPRECATED
def
encode
(
self
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
overload
def
encode
(
self
,
prompt
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
encode
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
...
...
@@ -411,8 +476,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -423,12 +487,17 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
async
def
_process_request
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -461,7 +530,7 @@ class MQLLMEngineClient:
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
28e1299e
...
...
@@ -271,7 +271,7 @@ class MQLLMEngine:
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
inputs
=
request
.
inputs
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
...
...
vllm/engine/protocol.py
View file @
28e1299e
...
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
Prompt
Inputs
from
vllm.inputs.data
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
def
generate
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate
s
outputs for a request"""
"""Generate outputs for a request
.
"""
...
def
encode
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
vllm/entrypoints/llm.py
View file @
28e1299e
...
...
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -293,8 +293,8 @@ class LLM:
@
overload
def
generate
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -304,14 +304,13 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -330,7 +329,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
...
...
@@ -358,12 +359,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration)."
)
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
...
...
@@ -378,7 +380,7 @@ class LLM:
sampling_params
=
SamplingParams
()
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -648,8 +650,8 @@ class LLM:
@
overload
def
encode
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -659,14 +661,13 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -682,9 +683,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inpu
ts: The
inpu
ts to the LLM. You may pass a sequence of
inputs for
batch inference. See :class:`~vllm.inputs.Prompt
Inputs
`
for more details about the format of each
input
.
promp
ts: The
promp
ts to the LLM. You may pass a sequence of
prompts
for
batch inference. See :class:`~vllm.inputs.Prompt
Type
`
for more details about the format of each
prompts
.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
...
...
@@ -707,19 +708,20 @@ class LLM:
)
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -763,9 +765,9 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
inpu
ts
:
List
[
Prompt
Inputs
]
=
[]
parsed_promp
ts
:
List
[
Prompt
Type
]
=
[]
for
i
in
range
(
num_requests
):
item
:
Prompt
Inputs
item
:
Prompt
Type
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
...
...
@@ -774,13 +776,13 @@ class LLM:
else
:
raise
AssertionError
inpu
ts
.
append
(
item
)
parsed_promp
ts
.
append
(
item
)
return
inpu
ts
return
parsed_promp
ts
def
_validate_and_add_requests
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
...
...
@@ -788,11 +790,11 @@ class LLM:
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
if
isinstance
(
inpu
ts
,
(
str
,
dict
)):
if
isinstance
(
promp
ts
,
(
str
,
dict
)):
# Convert a single prompt to a list.
inpu
ts
=
[
inpu
ts
]
promp
ts
=
[
promp
ts
]
num_requests
=
len
(
inpu
ts
)
num_requests
=
len
(
promp
ts
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
...
...
@@ -809,9 +811,9 @@ class LLM:
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inpu
ts
):
for
i
,
prompt
in
enumerate
(
promp
ts
):
self
.
_add_request
(
request_inputs
,
prompt
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
...
...
@@ -821,7 +823,7 @@ class LLM:
def
_add_request
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
@@ -830,7 +832,7 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
inputs
,
prompt
,
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/inputs/__init__.py
View file @
28e1299e
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
...
...
@@ -16,8 +16,8 @@ See also:
__all__
=
[
"TextPrompt"
,
"TokensPrompt"
,
"Prompt
Inputs
"
,
"SingletonPrompt
Inputs
"
,
"Prompt
Type
"
,
"SingletonPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
...
...
@@ -28,3 +28,17 @@ __all__ = [
"InputContext"
,
"InputRegistry"
,
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/data.py
View file @
28e1299e
...
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt
Inputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
Set of possible schemas for a single LLM input:
...
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPrompt
Inputs
` may be employed
A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
...
...
@@ -55,33 +55,33 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
,
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
,
covariant
=
True
)
# TODO: Make fields ReadOnly once mypy supports it
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPrompt
Inputs
` schemas, and are not
:class:`SingletonPrompt` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the
:code:
`encoder_prompt` and
:code:
`decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPrompt
Inputs
` instances.
:class:`SingletonPrompt` instances.
"""
encoder_prompt
:
_T1_co
...
...
@@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
Prompt
Inputs
=
Union
[
SingletonPrompt
Inputs
,
ExplicitEncoderDecoderPrompt
]
Prompt
Type
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
...
...
@@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
def
build_explicit_enc_dec_prompt
(
...
...
@@ -176,3 +172,17 @@ def to_enc_dec_tuple_list(
return
[(
enc_dec_prompt
[
"encoder_prompt"
],
enc_dec_prompt
[
"decoder_prompt"
])
for
enc_dec_prompt
in
enc_dec_prompts
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/parse.py
View file @
28e1299e
...
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
)
...
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def
parse_singleton_prompt
(
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
if
isinstance
(
inputs
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
inputs
)
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
if
isinstance
(
prompt
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
elif
isinstance
(
prompt
,
dict
):
if
"prompt_token_ids"
in
prompt
:
return
ParsedTokensPrompt
(
type
=
"tokens"
,
content
=
inputs
)
# type: ignore
elif
"prompt"
in
inputs
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
inputs
)
content
=
prompt
)
# type: ignore
elif
"prompt"
in
prompt
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
prompt
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
def
is_explicit_encoder_decoder_prompt
(
inputs
:
Prompt
Inputs
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
inputs
,
dict
)
and
"encoder_prompt"
in
inputs
prompt
:
Prompt
Type
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
def
is_valid_encoder_decoder_llm_inputs
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment