Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f1ba084
"vscode:/vscode.git/clone" did not exist on "c42ff4f4fdc4a4d48ccef18b8067995f6c19e6ec"
Unverified
Commit
4f1ba084
authored
Sep 25, 2024
by
Simon Mo
Committed by
GitHub
Sep 25, 2024
Browse files
Revert "rename PromptInputs and inputs with backward compatibility (#8760) (#8810)
parent
873edda6
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
202 additions
and
395 deletions
+202
-395
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+4
-4
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+1
-1
docs/source/dev/offline_inference/llm_inputs.rst
docs/source/dev/offline_inference/llm_inputs.rst
+1
-1
docs/source/models/vlm.rst
docs/source/models/vlm.rst
+1
-1
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+3
-5
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+34
-0
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+37
-0
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+6
-6
tests/mq_llm_engine/utils.py
tests/mq_llm_engine/utils.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+18
-92
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-45
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+3
-58
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+13
-82
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+33
-35
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-17
vllm/inputs/data.py
vllm/inputs/data.py
+19
-29
vllm/inputs/parse.py
vllm/inputs/parse.py
+11
-11
No files found.
benchmarks/benchmark_latency.py
View file @
4f1ba084
...
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.inputs
import
Prompt
Type
from
vllm.inputs
import
Prompt
Inputs
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
size
=
(
args
.
batch_size
,
args
.
input_len
))
dummy_
promp
ts
:
List
[
Prompt
Type
]
=
[{
dummy_
inpu
ts
:
List
[
Prompt
Inputs
]
=
[{
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
...
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
)))
as
p
:
llm
.
generate
(
dummy_
promp
ts
,
llm
.
generate
(
dummy_
inpu
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
print
(
p
.
key_averages
())
else
:
start_time
=
time
.
perf_counter
()
llm
.
generate
(
dummy_
promp
ts
,
llm
.
generate
(
dummy_
inpu
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
end_time
=
time
.
perf_counter
()
...
...
docs/source/dev/multimodal/multimodal_index.rst
View file @
4f1ba084
...
...
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Type
`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Inputs
`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
...
...
docs/source/dev/offline_inference/llm_inputs.rst
View file @
4f1ba084
LLM Inputs
==========
.. autodata:: vllm.inputs.Prompt
Type
.. autodata:: vllm.inputs.Prompt
Inputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
...
...
docs/source/models/vlm.rst
View file @
4f1ba084
...
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Type
`:
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Inputs
`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
...
...
tests/async_engine/test_async_llm_engine.py
View file @
4f1ba084
...
...
@@ -86,19 +86,17 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@
pytest
.
mark
.
asyncio
async
def
test_new_requests_event
():
params
=
SamplingParams
()
engine
=
MockAsyncLLMEngine
()
engine
.
start_background_loop
()
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
step_calls
==
0
await
engine
.
add_request
(
"1"
,
""
,
params
)
await
engine
.
add_request
(
"1"
,
""
,
None
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
1
assert
engine
.
engine
.
step_calls
==
1
await
engine
.
add_request
(
"2"
,
""
,
params
)
await
engine
.
add_request
(
"2"
,
""
,
None
)
engine
.
engine
.
generate
(
"2"
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
...
...
@@ -113,7 +111,7 @@ async def test_new_requests_event():
await
asyncio
.
sleep
(
0.001
)
assert
engine
.
engine
.
step_calls
==
old_step_calls
await
engine
.
add_request
(
"3"
,
""
,
params
)
await
engine
.
add_request
(
"3"
,
""
,
None
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
...
...
tests/entrypoints/llm/test_encode.py
View file @
4f1ba084
...
...
@@ -49,6 +49,21 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
prompt
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
prompt
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
({
"prompt"
:
prompt
},
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
...
@@ -64,6 +79,25 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
PROMPTS
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
pooling_params
=
pooling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
...
...
tests/entrypoints/llm/test_generate.py
View file @
4f1ba084
...
...
@@ -47,6 +47,23 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
({
"prompt"
:
prompt
},
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
...
@@ -62,6 +79,26 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
PROMPTS
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
sampling_params
=
sampling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
...
...
tests/mq_llm_engine/test_error_handling.py
View file @
4f1ba084
...
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with
pytest
.
raises
(
RAISED_ERROR
):
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(
max_tokens
=
10
),
request_id
=
uuid
.
uuid4
()):
pass
...
...
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with
pytest
.
raises
(
ValueError
):
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-1"
,
lora_request
=
LoRARequest
(
...
...
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-2"
):
pass
...
...
tests/mq_llm_engine/utils.py
View file @
4f1ba084
...
...
@@ -20,7 +20,7 @@ async def generate(
count
=
0
async
for
out
in
client
.
generate
(
request_id
=
request_id
,
prompt
=
"Hello my name is Robert and"
,
inputs
=
"Hello my name is Robert and"
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
temperature
=
0
)):
...
...
vllm/__init__.py
View file @
4f1ba084
...
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
...
...
@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__"
,
"LLM"
,
"ModelRegistry"
,
"Prompt
Type
"
,
"Prompt
Inputs
"
,
"TextPrompt"
,
"TokensPrompt"
,
"SamplingParams"
,
...
...
vllm/engine/async_llm_engine.py
View file @
4f1ba084
...
...
@@ -2,8 +2,8 @@ import asyncio
import
time
import
weakref
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Coroutine
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
overload
)
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
...
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Type
from
vllm.inputs
import
Prompt
Inputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
...
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
weak_bind
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -402,54 +402,17 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
@
overload
# DEPRECATED
async
def
add_request_async
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Async version of :meth:`add_request`."""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -457,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -811,55 +774,16 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# for backwards compatibility.
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]
]
=
None
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
...
...
@@ -873,7 +797,7 @@ class AsyncLLMEngine:
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
verbose
=
self
.
log_requests
,
prompt
=
prompt
,
inputs
=
inputs
,
params
=
params
,
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
...
...
@@ -884,7 +808,7 @@ class AsyncLLMEngine:
async
def
generate
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -898,7 +822,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -956,7 +881,7 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
prompt
,
inputs
,
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
...
...
@@ -966,7 +891,7 @@ class AsyncLLMEngine:
async
def
encode
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -979,7 +904,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -1033,7 +959,7 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
prompt
,
inputs
,
pooling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
...
...
vllm/engine/llm_engine.py
View file @
4f1ba084
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
overload
from
typing
import
Set
,
Type
,
Union
import
torch
from
typing_extensions
import
TypeVar
...
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
Prompt
Type
)
InputRegistry
,
LLMInputs
,
Prompt
Inputs
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -689,51 +689,16 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -743,7 +708,8 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
...
...
@@ -778,10 +744,6 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -794,7 +756,7 @@ class LLMEngine:
arrival_time
=
time
.
time
()
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/engine/multiprocessing/__init__.py
View file @
4f1ba084
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
,
overload
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
vllm
import
PoolingParams
from
vllm.inputs
import
Prompt
Type
from
vllm.inputs
import
Prompt
Inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
...
@@ -24,67 +23,13 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
class
RPCProcessRequest
:
prompt
:
Prompt
Type
inputs
:
Prompt
Inputs
params
:
Union
[
SamplingParams
,
PoolingParams
]
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
@
overload
# DEPRECATED
def
__init__
(
self
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
def
__init__
(
self
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
__init__
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
params
is
not
None
and
request_id
is
not
None
)
super
().
__init__
()
self
.
prompt
=
prompt
self
.
params
=
params
self
.
request_id
=
request_id
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
@
dataclass
class
RPCError
:
...
...
vllm/engine/multiprocessing/client.py
View file @
4f1ba084
...
...
@@ -3,7 +3,7 @@ import copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
,
overload
)
Union
)
import
cloudpickle
import
zmq
...
...
@@ -25,14 +25,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
Prompt
Type
from
vllm.inputs
import
Prompt
Inputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
logger
=
init_logger
(
__name__
)
...
...
@@ -368,45 +367,14 @@ class MQLLMEngineClient:
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
@
overload
# DEPRECATED
def
generate
(
self
,
*
,
inputs
:
PromptType
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
overload
def
generate
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
generate
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -415,7 +383,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -424,51 +393,17 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
sampling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
@
overload
# DEPRECATED
def
encode
(
self
,
*
,
inputs
:
PromptType
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
overload
def
encode
(
self
,
prompt
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
encode
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
...
...
@@ -477,7 +412,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -488,17 +424,12 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
async
def
_process_request
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -531,7 +462,7 @@ class MQLLMEngineClient:
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
prompt
=
prompt
,
inputs
=
inputs
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
4f1ba084
...
...
@@ -278,7 +278,7 @@ class MQLLMEngine:
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
prompt
=
request
.
prompt
,
inputs
=
request
.
inputs
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
...
...
vllm/engine/protocol.py
View file @
4f1ba084
...
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
Prompt
Type
from
vllm.inputs.data
import
Prompt
Inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
def
generate
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request
.
"""
"""Generate
s
outputs for a request"""
...
def
encode
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
vllm/entrypoints/llm.py
View file @
4f1ba084
...
...
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -293,8 +293,8 @@ class LLM:
@
overload
def
generate
(
self
,
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -304,13 +304,14 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -329,9 +330,7 @@ class LLM:
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
...
...
@@ -359,13 +358,12 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration)."
)
if
prompt_token_ids
is
not
None
:
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
inpu
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
...
...
@@ -380,7 +378,7 @@ class LLM:
sampling_params
=
SamplingParams
()
self
.
_validate_and_add_requests
(
prompts
=
parsed_promp
ts
,
inputs
=
inpu
ts
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -650,8 +648,8 @@ class LLM:
@
overload
def
encode
(
self
,
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -661,13 +659,14 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -683,9 +682,9 @@ class LLM:
into a single list and pass it to this method.
Args:
promp
ts: The
promp
ts to the LLM. You may pass a sequence of
prompts
for
batch inference. See :class:`~vllm.inputs.Prompt
Type
`
for more details about the format of each
prompts
.
inpu
ts: The
inpu
ts to the LLM. You may pass a sequence of
inputs for
batch inference. See :class:`~vllm.inputs.Prompt
Inputs
`
for more details about the format of each
input
.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
...
...
@@ -708,20 +707,19 @@ class LLM:
)
if
prompt_token_ids
is
not
None
:
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
inpu
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
prompts
=
parsed_promp
ts
,
inputs
=
inpu
ts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -765,9 +763,9 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
parsed_promp
ts
:
List
[
Prompt
Type
]
=
[]
inpu
ts
:
List
[
Prompt
Inputs
]
=
[]
for
i
in
range
(
num_requests
):
item
:
Prompt
Type
item
:
Prompt
Inputs
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
...
...
@@ -776,13 +774,13 @@ class LLM:
else
:
raise
AssertionError
parsed_promp
ts
.
append
(
item
)
inpu
ts
.
append
(
item
)
return
parsed_promp
ts
return
inpu
ts
def
_validate_and_add_requests
(
self
,
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
...
...
@@ -790,11 +788,11 @@ class LLM:
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
if
isinstance
(
promp
ts
,
(
str
,
dict
)):
if
isinstance
(
inpu
ts
,
(
str
,
dict
)):
# Convert a single prompt to a list.
promp
ts
=
[
promp
ts
]
inpu
ts
=
[
inpu
ts
]
num_requests
=
len
(
promp
ts
)
num_requests
=
len
(
inpu
ts
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
...
...
@@ -811,9 +809,9 @@ class LLM:
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
for
i
,
prompt
in
enumerate
(
promp
ts
):
for
i
,
request_inputs
in
enumerate
(
inpu
ts
):
self
.
_add_request
(
prompt
,
request_inputs
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
...
...
@@ -823,7 +821,7 @@ class LLM:
def
_add_request
(
self
,
prompt
:
Prompt
Type
,
inputs
:
Prompt
Inputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
@@ -832,7 +830,7 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
prompt
,
inputs
,
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/inputs/__init__.py
View file @
4f1ba084
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
...
...
@@ -16,8 +16,8 @@ See also:
__all__
=
[
"TextPrompt"
,
"TokensPrompt"
,
"Prompt
Type
"
,
"SingletonPrompt"
,
"Prompt
Inputs
"
,
"SingletonPrompt
Inputs
"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
...
...
@@ -28,17 +28,3 @@ __all__ = [
"InputContext"
,
"InputRegistry"
,
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/data.py
View file @
4f1ba084
...
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPrompt
Inputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
Set of possible schemas for a single LLM input:
...
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPrompt` may be employed
A prompt of type :class:`SingletonPrompt
Inputs
` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
...
...
@@ -55,33 +55,33 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
,
bound
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
Inputs
,
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
,
bound
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
Inputs
,
covariant
=
True
)
# TODO: Make fields ReadOnly once mypy supports it
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPrompt` schemas, and are not
:class:`SingletonPrompt
Inputs
` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the
:code:
`encoder_prompt` and
:code:
`decoder_prompt`
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPrompt` instances.
:class:`SingletonPrompt
Inputs
` instances.
"""
encoder_prompt
:
_T1_co
...
...
@@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
Prompt
Type
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
Prompt
Inputs
=
Union
[
SingletonPrompt
Inputs
,
ExplicitEncoderDecoderPrompt
]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
...
...
@@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
def
build_explicit_enc_dec_prompt
(
...
...
@@ -172,17 +176,3 @@ def to_enc_dec_tuple_list(
return
[(
enc_dec_prompt
[
"encoder_prompt"
],
enc_dec_prompt
[
"decoder_prompt"
])
for
enc_dec_prompt
in
enc_dec_prompts
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/parse.py
View file @
4f1ba084
...
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
TokensPrompt
)
...
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def
parse_singleton_prompt
(
prompt
:
SingletonPrompt
,
inputs
:
SingletonPrompt
Inputs
,
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
if
isinstance
(
prompt
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
elif
isinstance
(
prompt
,
dict
):
if
"prompt_token_ids"
in
prompt
:
if
isinstance
(
inputs
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
inputs
)
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
return
ParsedTokensPrompt
(
type
=
"tokens"
,
content
=
prompt
)
# type: ignore
elif
"prompt"
in
prompt
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
prompt
)
content
=
inputs
)
# type: ignore
elif
"prompt"
in
inputs
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
inputs
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
def
is_explicit_encoder_decoder_prompt
(
prompt
:
Prompt
Type
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
inputs
:
Prompt
Inputs
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
inputs
,
dict
)
and
"encoder_prompt"
in
inputs
def
is_valid_encoder_decoder_llm_inputs
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment