Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b00b9c2
Unverified
Commit
3b00b9c2
authored
Sep 27, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 26, 2024
Browse files
[Core] rename`PromptInputs` and `inputs` (#8876)
parent
344cd2b6
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
397 additions
and
205 deletions
+397
-205
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+4
-4
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+1
-1
docs/source/dev/offline_inference/llm_inputs.rst
docs/source/dev/offline_inference/llm_inputs.rst
+1
-1
docs/source/models/vlm.rst
docs/source/models/vlm.rst
+1
-1
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+5
-3
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+0
-34
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+0
-37
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+6
-6
tests/mq_llm_engine/utils.py
tests/mq_llm_engine/utils.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+92
-18
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+45
-7
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+58
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+82
-13
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+35
-33
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+17
-3
vllm/inputs/data.py
vllm/inputs/data.py
+31
-22
vllm/inputs/parse.py
vllm/inputs/parse.py
+11
-11
No files found.
benchmarks/benchmark_latency.py
View file @
3b00b9c2
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
size
=
(
args
.
batch_size
,
size
=
(
args
.
batch_size
,
args
.
input_len
))
args
.
input_len
))
dummy_
inpu
ts
:
List
[
Prompt
Inputs
]
=
[{
dummy_
promp
ts
:
List
[
Prompt
Type
]
=
[{
"prompt_token_ids"
:
batch
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
],
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
)))
as
p
:
str
(
profile_dir
)))
as
p
:
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
print
(
p
.
key_averages
())
print
(
p
.
key_averages
())
else
:
else
:
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
...
...
docs/source/dev/multimodal/multimodal_index.rst
View file @
3b00b9c2
...
@@ -8,7 +8,7 @@ Multi-Modality
...
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Inputs
`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Type
`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
by following :ref:`this guide <adding_multimodal_plugin>`.
...
...
docs/source/dev/offline_inference/llm_inputs.rst
View file @
3b00b9c2
LLM Inputs
LLM Inputs
==========
==========
.. autodata:: vllm.inputs.Prompt
Inputs
.. autodata:: vllm.inputs.Prompt
Type
.. autoclass:: vllm.inputs.TextPrompt
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:show-inheritance:
...
...
docs/source/models/vlm.rst
View file @
3b00b9c2
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Inputs
`:
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Type
`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
...
...
tests/async_engine/test_async_llm_engine.py
View file @
3b00b9c2
...
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
...
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_new_requests_event
():
async
def
test_new_requests_event
():
params
=
SamplingParams
()
engine
=
MockAsyncLLMEngine
()
engine
=
MockAsyncLLMEngine
()
engine
.
start_background_loop
()
engine
.
start_background_loop
()
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
step_calls
==
0
assert
engine
.
engine
.
step_calls
==
0
await
engine
.
add_request
(
"1"
,
""
,
None
)
await
engine
.
add_request
(
"1"
,
""
,
params
)
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
1
assert
engine
.
engine
.
add_request_calls
==
1
assert
engine
.
engine
.
step_calls
==
1
assert
engine
.
engine
.
step_calls
==
1
await
engine
.
add_request
(
"2"
,
""
,
None
)
await
engine
.
add_request
(
"2"
,
""
,
params
)
engine
.
engine
.
generate
(
"2"
)
engine
.
engine
.
generate
(
"2"
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
...
@@ -111,7 +113,7 @@ async def test_new_requests_event():
...
@@ -111,7 +113,7 @@ async def test_new_requests_event():
await
asyncio
.
sleep
(
0.001
)
await
asyncio
.
sleep
(
0.001
)
assert
engine
.
engine
.
step_calls
==
old_step_calls
assert
engine
.
engine
.
step_calls
==
old_step_calls
await
engine
.
add_request
(
"3"
,
""
,
None
)
await
engine
.
add_request
(
"3"
,
""
,
params
)
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
...
...
tests/entrypoints/llm/test_encode.py
View file @
3b00b9c2
...
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
...
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
prompt
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
prompt
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
({
"prompt"
:
prompt
},
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
...
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
PROMPTS
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
pooling_params
=
pooling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
...
...
tests/entrypoints/llm/test_generate.py
View file @
3b00b9c2
...
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
...
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
({
"prompt"
:
prompt
},
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
...
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
...
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal
(
v1_output
,
v2_output
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
PROMPTS
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
sampling_params
=
sampling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
...
...
tests/mq_llm_engine/test_error_handling.py
View file @
3b00b9c2
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
# Throws an error in first forward pass.
with
pytest
.
raises
(
RAISED_ERROR
):
with
pytest
.
raises
(
RAISED_ERROR
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
# Engine is errored, should get ENGINE_DEAD_ERROR.
with
pytest
.
raises
(
MQEngineDeadError
):
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
# Generate call should throw ENGINE_DEAD_ERROR
with
pytest
.
raises
(
MQEngineDeadError
):
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
...
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
# with reference to the original KeyError("foo")
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
async
for
_
in
client
.
generate
(
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(
max_tokens
=
10
),
sampling_params
=
SamplingParams
(
max_tokens
=
10
),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
...
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
# Invalid request should fail, but not crash the server.
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-1"
,
request_id
=
"abcd-1"
,
lora_request
=
LoRARequest
(
lora_request
=
LoRARequest
(
...
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
...
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass
pass
# This request should be okay.
# This request should be okay.
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-2"
):
request_id
=
"abcd-2"
):
pass
pass
...
...
tests/mq_llm_engine/utils.py
View file @
3b00b9c2
...
@@ -20,7 +20,7 @@ async def generate(
...
@@ -20,7 +20,7 @@ async def generate(
count
=
0
count
=
0
async
for
out
in
client
.
generate
(
async
for
out
in
client
.
generate
(
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
"Hello my name is Robert and"
,
prompt
=
"Hello my name is Robert and"
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
temperature
=
0
)):
temperature
=
0
)):
...
...
vllm/__init__.py
View file @
3b00b9c2
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
EmbeddingRequestOutput
,
RequestOutput
)
...
@@ -19,7 +19,7 @@ __all__ = [
...
@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__"
,
"__version_tuple__"
,
"LLM"
,
"LLM"
,
"ModelRegistry"
,
"ModelRegistry"
,
"Prompt
Inputs
"
,
"Prompt
Type
"
,
"TextPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"TokensPrompt"
,
"SamplingParams"
,
"SamplingParams"
,
...
...
vllm/engine/async_llm_engine.py
View file @
3b00b9c2
...
@@ -2,8 +2,8 @@ import asyncio
...
@@ -2,8 +2,8 @@ import asyncio
import
time
import
time
import
weakref
import
weakref
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Coroutine
,
Dict
,
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
overload
)
from
weakref
import
ReferenceType
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
@@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
@
overload
# DEPRECATED
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
)
->
None
:
"""Async version of :meth:`add_request`."""
"""Async version of :meth:`add_request`."""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
...
@@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -774,16 +811,55 @@ class AsyncLLMEngine:
...
@@ -774,16 +811,55 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# This method does not need to be async, but kept that way
# for backwards compatibility.
# for backwards compatibility.
async
def
add_request
(
@
overload
# DEPRECATED
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
not
self
.
is_running
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
self
.
start_background_loop
()
...
@@ -797,7 +873,7 @@ class AsyncLLMEngine:
...
@@ -797,7 +873,7 @@ class AsyncLLMEngine:
stream
=
self
.
_request_tracker
.
add_request
(
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
request_id
,
verbose
=
self
.
log_requests
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
params
=
params
,
arrival_time
=
arrival_time
or
time
.
time
(),
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -808,7 +884,7 @@ class AsyncLLMEngine:
...
@@ -808,7 +884,7 @@ class AsyncLLMEngine:
async
def
generate
(
async
def
generate
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -822,8 +898,7 @@ class AsyncLLMEngine:
...
@@ -822,8 +898,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -881,7 +956,7 @@ class AsyncLLMEngine:
...
@@ -881,7 +956,7 @@ class AsyncLLMEngine:
"""
"""
async
for
output
in
await
self
.
add_request
(
async
for
output
in
await
self
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
sampling_params
,
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
...
@@ -891,7 +966,7 @@ class AsyncLLMEngine:
...
@@ -891,7 +966,7 @@ class AsyncLLMEngine:
async
def
encode
(
async
def
encode
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -904,8 +979,7 @@ class AsyncLLMEngine:
...
@@ -904,8 +979,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -959,7 +1033,7 @@ class AsyncLLMEngine:
...
@@ -959,7 +1033,7 @@ class AsyncLLMEngine:
"""
"""
async
for
output
in
await
self
.
add_request
(
async
for
output
in
await
self
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
pooling_params
,
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
...
...
vllm/engine/llm_engine.py
View file @
3b00b9c2
...
@@ -6,7 +6,7 @@ from functools import partial
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
from
typing
import
Set
,
Type
,
Union
,
overload
import
torch
import
torch
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
Prompt
Inputs
)
InputRegistry
,
LLMInputs
,
Prompt
Type
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -689,16 +689,51 @@ class LLMEngine:
...
@@ -689,16 +689,51 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
@
overload
# DEPRECATED
def
add_request
(
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -708,8 +743,7 @@ class LLMEngine:
...
@@ -708,8 +743,7 @@ class LLMEngine:
Args:
Args:
request_id: The unique ID of the request.
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
params: Parameters for sampling or pooling.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.SamplingParams` for text generation.
...
@@ -744,6 +778,10 @@ class LLMEngine:
...
@@ -744,6 +778,10 @@ class LLMEngine:
>>> # continue the request processing
>>> # continue the request processing
>>> ...
>>> ...
"""
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
...
@@ -756,7 +794,7 @@ class LLMEngine:
...
@@ -756,7 +794,7 @@ class LLMEngine:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/engine/multiprocessing/__init__.py
View file @
3b00b9c2
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
,
overload
from
vllm
import
PoolingParams
from
vllm
import
PoolingParams
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
@@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):
...
@@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
@
dataclass
class
RPCProcessRequest
:
class
RPCProcessRequest
:
inputs
:
Prompt
Inputs
prompt
:
Prompt
Type
params
:
Union
[
SamplingParams
,
PoolingParams
]
params
:
Union
[
SamplingParams
,
PoolingParams
]
request_id
:
str
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
@
overload
# DEPRECATED
def
__init__
(
self
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
overload
def
__init__
(
self
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
__init__
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
params
is
not
None
and
request_id
is
not
None
)
super
().
__init__
()
self
.
prompt
=
prompt
self
.
params
=
params
self
.
request_id
=
request_id
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
@
dataclass
@
dataclass
class
RPCError
:
class
RPCError
:
...
...
vllm/engine/multiprocessing/client.py
View file @
3b00b9c2
...
@@ -3,7 +3,7 @@ import copy
...
@@ -3,7 +3,7 @@ import copy
import
pickle
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
)
Union
,
overload
)
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
...
@@ -25,13 +25,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -25,13 +25,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCUProfileRequest
)
RPCUProfileRequest
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -367,14 +368,45 @@ class MQLLMEngineClient:
...
@@ -367,14 +368,45 @@ class MQLLMEngineClient:
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
@
overload
# DEPRECATED
def
generate
(
def
generate
(
self
,
self
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
overload
def
generate
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
generate
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
"""Generate outputs for a request.
...
@@ -383,8 +415,7 @@ class MQLLMEngineClient:
...
@@ -383,8 +415,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -393,17 +424,51 @@ class MQLLMEngineClient:
...
@@ -393,17 +424,51 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
for generation, if any.
"""
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
sampling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
prompt_adapter_request
)
@
overload
# DEPRECATED
def
encode
(
def
encode
(
self
,
self
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
overload
def
encode
(
self
,
prompt
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
encode
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
"""Generate outputs for a request from an embedding model.
...
@@ -412,8 +477,7 @@ class MQLLMEngineClient:
...
@@ -412,8 +477,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -424,12 +488,17 @@ class MQLLMEngineClient:
...
@@ -424,12 +488,17 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
for the request.
"""
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
lora_request
,
trace_headers
)
async
def
_process_request
(
async
def
_process_request
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -462,7 +531,7 @@ class MQLLMEngineClient:
...
@@ -462,7 +531,7 @@ class MQLLMEngineClient:
request_bytes
=
pickle
.
dumps
(
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
RPCProcessRequest
(
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
params
=
params
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
3b00b9c2
...
@@ -278,7 +278,7 @@ class MQLLMEngine:
...
@@ -278,7 +278,7 @@ class MQLLMEngine:
try
:
try
:
self
.
engine
.
add_request
(
self
.
engine
.
add_request
(
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
request
.
inputs
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
trace_headers
=
request
.
trace_headers
,
...
...
vllm/engine/protocol.py
View file @
3b00b9c2
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
Prompt
Inputs
from
vllm.inputs.data
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
def
generate
(
def
generate
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate
s
outputs for a request"""
"""Generate outputs for a request
.
"""
...
...
def
encode
(
def
encode
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
vllm/entrypoints/llm.py
View file @
3b00b9c2
...
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
...
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
parse_chat_messages
)
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -293,8 +293,8 @@ class LLM:
...
@@ -293,8 +293,8 @@ class LLM:
@
overload
@
overload
def
generate
(
def
generate
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
# We may enable `inputs` keyword after removing the old API
/
,
*
,
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
...
@@ -304,14 +304,13 @@ class LLM:
...
@@ -304,14 +304,13 @@ class LLM:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
)
def
generate
(
def
generate
(
self
,
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
...
@@ -330,7 +329,9 @@ class LLM:
...
@@ -330,7 +329,9 @@ class LLM:
into a single list and pass it to this method.
into a single list and pass it to this method.
Args:
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a single value, it is applied to every prompt.
...
@@ -358,12 +359,13 @@ class LLM:
...
@@ -358,12 +359,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration)."
)
"models (XForCausalLM, XForConditionalGeneration)."
)
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
)
)
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
if
len
(
guided_options_request
)
>
1
:
...
@@ -378,7 +380,7 @@ class LLM:
...
@@ -378,7 +380,7 @@ class LLM:
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -648,8 +650,8 @@ class LLM:
...
@@ -648,8 +650,8 @@ class LLM:
@
overload
@
overload
def
encode
(
def
encode
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
# We may enable `inputs` keyword after removing the old API
/
,
*
,
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
...
@@ -659,14 +661,13 @@ class LLM:
...
@@ -659,14 +661,13 @@ class LLM:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
)
def
encode
(
def
encode
(
self
,
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
...
@@ -682,9 +683,9 @@ class LLM:
...
@@ -682,9 +683,9 @@ class LLM:
into a single list and pass it to this method.
into a single list and pass it to this method.
Args:
Args:
inpu
ts: The
inpu
ts to the LLM. You may pass a sequence of
inputs for
promp
ts: The
promp
ts to the LLM. You may pass a sequence of
prompts
batch inference. See :class:`~vllm.inputs.Prompt
Inputs
`
for
batch inference. See :class:`~vllm.inputs.Prompt
Type
`
for more details about the format of each
input
.
for more details about the format of each
prompts
.
pooling_params: The pooling parameters for pooling. If None, we
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
...
@@ -707,19 +708,20 @@ class LLM:
...
@@ -707,19 +708,20 @@ class LLM:
)
)
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
)
)
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
pooling_params
is
None
:
if
pooling_params
is
None
:
# Use default pooling params.
# Use default pooling params.
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -763,9 +765,9 @@ class LLM:
...
@@ -763,9 +765,9 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
"provided."
)
inpu
ts
:
List
[
Prompt
Inputs
]
=
[]
parsed_promp
ts
:
List
[
Prompt
Type
]
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
item
:
Prompt
Inputs
item
:
Prompt
Type
if
prompts
is
not
None
:
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
...
@@ -774,13 +776,13 @@ class LLM:
...
@@ -774,13 +776,13 @@ class LLM:
else
:
else
:
raise
AssertionError
raise
AssertionError
inpu
ts
.
append
(
item
)
parsed_promp
ts
.
append
(
item
)
return
inpu
ts
return
parsed_promp
ts
def
_validate_and_add_requests
(
def
_validate_and_add_requests
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
...
@@ -788,11 +790,11 @@ class LLM:
...
@@ -788,11 +790,11 @@ class LLM:
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inpu
ts
,
(
str
,
dict
)):
if
isinstance
(
promp
ts
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
inpu
ts
=
[
inpu
ts
]
promp
ts
=
[
promp
ts
]
num_requests
=
len
(
inpu
ts
)
num_requests
=
len
(
promp
ts
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
"must be the same."
)
...
@@ -809,9 +811,9 @@ class LLM:
...
@@ -809,9 +811,9 @@ class LLM:
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inpu
ts
):
for
i
,
prompt
in
enumerate
(
promp
ts
):
self
.
_add_request
(
self
.
_add_request
(
request_inputs
,
prompt
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
lora_request
,
Sequence
)
else
lora_request
,
...
@@ -821,7 +823,7 @@ class LLM:
...
@@ -821,7 +823,7 @@ class LLM:
def
_add_request
(
def
_add_request
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
@@ -830,7 +832,7 @@ class LLM:
...
@@ -830,7 +832,7 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
self
.
llm_engine
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
params
,
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/inputs/__init__.py
View file @
3b00b9c2
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
...
@@ -16,8 +16,8 @@ See also:
...
@@ -16,8 +16,8 @@ See also:
__all__
=
[
__all__
=
[
"TextPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"TokensPrompt"
,
"Prompt
Inputs
"
,
"Prompt
Type
"
,
"SingletonPrompt
Inputs
"
,
"SingletonPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
"EncoderDecoderLLMInputs"
,
...
@@ -28,3 +28,17 @@ __all__ = [
...
@@ -28,3 +28,17 @@ __all__ = [
"InputContext"
,
"InputContext"
,
"InputRegistry"
,
"InputRegistry"
,
]
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/data.py
View file @
3b00b9c2
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
"""
SingletonPrompt
Inputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
Set of possible schemas for a single LLM input:
Set of possible schemas for a single LLM input:
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPrompt
Inputs
` may be employed
A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
where the decoder-prompt is not specified explicitly, or
...
@@ -55,33 +55,32 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
...
@@ -55,33 +55,32 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
,
covariant
=
True
)
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
,
covariant
=
True
)
covariant
=
True
)
# TODO: Make fields ReadOnly once mypy supports it
# TODO: Make fields ReadOnly once mypy supports it
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
"""
Represents an encoder/decoder model input prompt,
"""
comprising an explicit encoder prompt and a
Represents an encoder/decoder model input prompt,
decoder prompt.
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
The encoder and decoder prompts, respectively, may be formatted
may formatted according to any of the
according to any of the :class:`SingletonPrompt` schemas,
:class:`SingletonPromptInputs` schemas, and are not
and are not required to have the same schema.
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the
:code:
`encoder_prompt` and
:code:
`decoder_prompt`
fields of this data structure themselves must be
fields of this data structure themselves must be
:class:`SingletonPrompt
Inputs
` instances.
:class:`SingletonPrompt` instances.
"""
"""
encoder_prompt
:
_T1_co
encoder_prompt
:
_T1_co
...
@@ -89,7 +88,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -89,7 +88,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
decoder_prompt
:
Optional
[
_T2_co
]
Prompt
Inputs
=
Union
[
SingletonPrompt
Inputs
,
ExplicitEncoderDecoderPrompt
]
Prompt
Type
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
"""
"""
Set of possible schemas for an LLM input, including
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
both decoder-only and encoder/decoder input types:
...
@@ -146,12 +145,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
...
@@ -146,12 +145,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
"""
_T1
=
TypeVar
(
"_T1"
,
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
bound
=
SingletonPromptInputs
,
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
default
=
SingletonPromptInputs
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
def
build_explicit_enc_dec_prompt
(
def
build_explicit_enc_dec_prompt
(
...
@@ -182,3 +177,17 @@ def to_enc_dec_tuple_list(
...
@@ -182,3 +177,17 @@ def to_enc_dec_tuple_list(
return
[(
enc_dec_prompt
[
"encoder_prompt"
],
return
[(
enc_dec_prompt
[
"encoder_prompt"
],
enc_dec_prompt
[
"decoder_prompt"
])
enc_dec_prompt
[
"decoder_prompt"
])
for
enc_dec_prompt
in
enc_dec_prompts
]
for
enc_dec_prompt
in
enc_dec_prompts
]
def
__getattr__
(
name
:
str
):
if
name
==
"PromptInput"
:
import
warnings
msg
=
(
"PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version."
)
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
2
)
return
PromptType
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
vllm/inputs/parse.py
View file @
3b00b9c2
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
)
TokensPrompt
)
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def
parse_singleton_prompt
(
def
parse_singleton_prompt
(
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
if
isinstance
(
inputs
,
str
):
if
isinstance
(
prompt
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
inputs
)
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
elif
isinstance
(
inputs
,
dict
):
elif
isinstance
(
prompt
,
dict
):
if
"prompt_token_ids"
in
inputs
:
if
"prompt_token_ids"
in
prompt
:
return
ParsedTokensPrompt
(
type
=
"tokens"
,
return
ParsedTokensPrompt
(
type
=
"tokens"
,
content
=
inputs
)
# type: ignore
content
=
prompt
)
# type: ignore
elif
"prompt"
in
inputs
:
elif
"prompt"
in
prompt
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
inputs
)
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
prompt
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
def
is_explicit_encoder_decoder_prompt
(
def
is_explicit_encoder_decoder_prompt
(
inputs
:
Prompt
Inputs
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
prompt
:
Prompt
Type
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
inputs
,
dict
)
and
"encoder_prompt"
in
inputs
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
def
is_valid_encoder_decoder_llm_inputs
(
def
is_valid_encoder_decoder_llm_inputs
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment