Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0057894e
"vllm/vscode:/vscode.git/clone" did not exist on "cc867be19c4c1480ce399bf95db4fd5791c91cbd"
Unverified
Commit
0057894e
authored
Sep 21, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 20, 2024
Browse files
[Core] Rename `PromptInputs` and `inputs`(#8673)
parent
0f961b3c
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
157 additions
and
162 deletions
+157
-162
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+4
-4
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+1
-1
docs/source/dev/offline_inference/llm_inputs.rst
docs/source/dev/offline_inference/llm_inputs.rst
+1
-1
docs/source/models/vlm.rst
docs/source/models/vlm.rst
+1
-1
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+6
-6
tests/mq_llm_engine/utils.py
tests/mq_llm_engine/utils.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+11
-13
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-5
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+2
-2
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+9
-11
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+42
-38
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-3
vllm/inputs/data.py
vllm/inputs/data.py
+11
-15
vllm/inputs/parse.py
vllm/inputs/parse.py
+11
-11
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+43
-43
No files found.
benchmarks/benchmark_latency.py
View file @
0057894e
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
...
@@ -11,7 +11,7 @@ from tqdm import tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
EngineArgs
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
...
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
size
=
(
args
.
batch_size
,
size
=
(
args
.
batch_size
,
args
.
input_len
))
args
.
input_len
))
dummy_
inpu
ts
:
List
[
Prompt
Inputs
]
=
[{
dummy_
promp
ts
:
List
[
Prompt
Type
]
=
[{
"prompt_token_ids"
:
batch
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
...
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
],
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
)))
as
p
:
str
(
profile_dir
)))
as
p
:
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
print
(
p
.
key_averages
())
print
(
p
.
key_averages
())
else
:
else
:
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
llm
.
generate
(
dummy_
inpu
ts
,
llm
.
generate
(
dummy_
promp
ts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
...
...
docs/source/dev/multimodal/multimodal_index.rst
View file @
0057894e
...
@@ -8,7 +8,7 @@ Multi-Modality
...
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Inputs
`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.Prompt
Type
`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
by following :ref:`this guide <adding_multimodal_plugin>`.
...
...
docs/source/dev/offline_inference/llm_inputs.rst
View file @
0057894e
LLM Inputs
LLM Inputs
==========
==========
.. autodata:: vllm.inputs.Prompt
Inputs
.. autodata:: vllm.inputs.Prompt
Type
.. autoclass:: vllm.inputs.TextPrompt
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:show-inheritance:
...
...
docs/source/models/vlm.rst
View file @
0057894e
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
...
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Inputs
`:
To pass an image to the model, note the following in :class:`vllm.inputs.Prompt
Type
`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
...
...
tests/mq_llm_engine/test_error_handling.py
View file @
0057894e
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
...
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
# Throws an error in first forward pass.
with
pytest
.
raises
(
RAISED_ERROR
):
with
pytest
.
raises
(
RAISED_ERROR
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
...
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
# Engine is errored, should get ENGINE_DEAD_ERROR.
with
pytest
.
raises
(
MQEngineDeadError
):
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
...
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
# Generate call should throw ENGINE_DEAD_ERROR
with
pytest
.
raises
(
MQEngineDeadError
):
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -165,7 +165,7 @@ async def test_failed_abort(tmp_socket):
...
@@ -165,7 +165,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
# with reference to the original KeyError("foo")
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
async
for
_
in
client
.
generate
(
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(
max_tokens
=
2000
),
sampling_params
=
SamplingParams
(
max_tokens
=
2000
),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
...
@@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):
...
@@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
# Invalid request should fail, but not crash the server.
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-1"
,
request_id
=
"abcd-1"
,
lora_request
=
LoRARequest
(
lora_request
=
LoRARequest
(
...
@@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
...
@@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
pass
pass
# This request should be okay.
# This request should be okay.
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
request_id
=
"abcd-2"
):
request_id
=
"abcd-2"
):
pass
pass
...
...
tests/mq_llm_engine/utils.py
View file @
0057894e
...
@@ -20,7 +20,7 @@ async def generate(
...
@@ -20,7 +20,7 @@ async def generate(
count
=
0
count
=
0
async
for
out
in
client
.
generate
(
async
for
out
in
client
.
generate
(
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
"Hello my name is Robert and"
,
prompt
=
"Hello my name is Robert and"
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
sampling_params
=
SamplingParams
(
max_tokens
=
num_tokens
,
temperature
=
0
)):
temperature
=
0
)):
...
...
vllm/__init__.py
View file @
0057894e
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
EmbeddingRequestOutput
,
RequestOutput
)
...
@@ -19,7 +19,7 @@ __all__ = [
...
@@ -19,7 +19,7 @@ __all__ = [
"__version__"
,
"__version__"
,
"LLM"
,
"LLM"
,
"ModelRegistry"
,
"ModelRegistry"
,
"Prompt
Inputs
"
,
"Prompt
Type
"
,
"TextPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"TokensPrompt"
,
"SamplingParams"
,
"SamplingParams"
,
...
...
vllm/engine/async_llm_engine.py
View file @
0057894e
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
...
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -405,7 +405,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -405,7 +405,7 @@ class _AsyncLLMEngine(LLMEngine):
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -420,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -420,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -777,7 +777,7 @@ class AsyncLLMEngine:
...
@@ -777,7 +777,7 @@ class AsyncLLMEngine:
async
def
add_request
(
async
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -797,7 +797,7 @@ class AsyncLLMEngine:
...
@@ -797,7 +797,7 @@ class AsyncLLMEngine:
stream
=
self
.
_request_tracker
.
add_request
(
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
request_id
,
verbose
=
self
.
log_requests
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
params
=
params
,
arrival_time
=
arrival_time
or
time
.
time
(),
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -808,7 +808,7 @@ class AsyncLLMEngine:
...
@@ -808,7 +808,7 @@ class AsyncLLMEngine:
async
def
generate
(
async
def
generate
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -822,8 +822,7 @@ class AsyncLLMEngine:
...
@@ -822,8 +822,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -881,7 +880,7 @@ class AsyncLLMEngine:
...
@@ -881,7 +880,7 @@ class AsyncLLMEngine:
"""
"""
async
for
output
in
await
self
.
add_request
(
async
for
output
in
await
self
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
sampling_params
,
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
...
@@ -891,7 +890,7 @@ class AsyncLLMEngine:
...
@@ -891,7 +890,7 @@ class AsyncLLMEngine:
async
def
encode
(
async
def
encode
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -904,8 +903,7 @@ class AsyncLLMEngine:
...
@@ -904,8 +903,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -959,7 +957,7 @@ class AsyncLLMEngine:
...
@@ -959,7 +957,7 @@ class AsyncLLMEngine:
"""
"""
async
for
output
in
await
self
.
add_request
(
async
for
output
in
await
self
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
pooling_params
,
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
...
...
vllm/engine/llm_engine.py
View file @
0057894e
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
...
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
Prompt
Inputs
)
InputRegistry
,
LLMInputs
,
Prompt
Type
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -680,7 +680,7 @@ class LLMEngine:
...
@@ -680,7 +680,7 @@ class LLMEngine:
def
add_request
(
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -695,8 +695,7 @@ class LLMEngine:
...
@@ -695,8 +695,7 @@ class LLMEngine:
Args:
Args:
request_id: The unique ID of the request.
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
params: Parameters for sampling or pooling.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.SamplingParams` for text generation.
...
@@ -736,7 +735,7 @@ class LLMEngine:
...
@@ -736,7 +735,7 @@ class LLMEngine:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/engine/multiprocessing/__init__.py
View file @
0057894e
...
@@ -3,7 +3,7 @@ from enum import Enum
...
@@ -3,7 +3,7 @@ from enum import Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
vllm
import
PoolingParams
from
vllm
import
PoolingParams
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
...
@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
@
dataclass
class
RPCProcessRequest
:
class
RPCProcessRequest
:
inputs
:
Prompt
Inputs
prompt
:
Prompt
Type
params
:
Union
[
SamplingParams
,
PoolingParams
]
params
:
Union
[
SamplingParams
,
PoolingParams
]
request_id
:
str
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
...
...
vllm/engine/multiprocessing/client.py
View file @
0057894e
...
@@ -25,7 +25,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -25,7 +25,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupResponse
)
RPCStartupResponse
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
@@ -375,7 +375,7 @@ class MQLLMEngineClient:
...
@@ -375,7 +375,7 @@ class MQLLMEngineClient:
def
generate
(
def
generate
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -389,8 +389,7 @@ class MQLLMEngineClient:
...
@@ -389,8 +389,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -399,13 +398,13 @@ class MQLLMEngineClient:
...
@@ -399,13 +398,13 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
for generation, if any.
"""
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
prompt_adapter_request
)
def
encode
(
def
encode
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -418,8 +417,7 @@ class MQLLMEngineClient:
...
@@ -418,8 +417,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
inputs: The inputs to the LLM. See
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
...
@@ -430,12 +428,12 @@ class MQLLMEngineClient:
...
@@ -430,12 +428,12 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
for the request.
"""
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
lora_request
,
trace_headers
)
async
def
_process_request
(
async
def
_process_request
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -468,7 +466,7 @@ class MQLLMEngineClient:
...
@@ -468,7 +466,7 @@ class MQLLMEngineClient:
request_bytes
=
pickle
.
dumps
(
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
RPCProcessRequest
(
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
params
=
params
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
0057894e
...
@@ -245,7 +245,7 @@ class MQLLMEngine:
...
@@ -245,7 +245,7 @@ class MQLLMEngine:
try
:
try
:
self
.
engine
.
add_request
(
self
.
engine
.
add_request
(
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
request
.
inputs
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
trace_headers
=
request
.
trace_headers
,
...
...
vllm/engine/protocol.py
View file @
0057894e
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
...
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
Prompt
Inputs
from
vllm.inputs.data
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
...
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
def
generate
(
def
generate
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate
s
outputs for a request"""
"""Generate outputs for a request
.
"""
...
...
def
encode
(
def
encode
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
vllm/entrypoints/llm.py
View file @
0057894e
...
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
...
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
parse_chat_messages
)
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -258,8 +258,8 @@ class LLM:
...
@@ -258,8 +258,8 @@ class LLM:
@
overload
@
overload
def
generate
(
def
generate
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
# We may enable `inputs` keyword after removing the old API
/
,
*
,
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
...
@@ -276,7 +276,7 @@ class LLM:
...
@@ -276,7 +276,7 @@ class LLM:
)
)
def
generate
(
def
generate
(
self
,
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
...
@@ -294,7 +294,9 @@ class LLM:
...
@@ -294,7 +294,9 @@ class LLM:
into a single list and pass it to this method.
into a single list and pass it to this method.
Args:
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a single value, it is applied to every prompt.
...
@@ -320,12 +322,13 @@ class LLM:
...
@@ -320,12 +322,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration)."
)
"models (XForCausalLM, XForConditionalGeneration)."
)
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
)
)
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
if
len
(
guided_options_request
)
>
1
:
...
@@ -340,7 +343,7 @@ class LLM:
...
@@ -340,7 +343,7 @@ class LLM:
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -396,9 +399,9 @@ class LLM:
...
@@ -396,9 +399,9 @@ class LLM:
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
tokenizer
)
prompt
:
Union
[
str
,
List
[
int
]]
prompt
_data
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
prompt
_data
=
apply_mistral_chat_template
(
tokenizer
,
tokenizer
,
messages
=
messages
,
messages
=
messages
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
...
@@ -406,7 +409,7 @@ class LLM:
...
@@ -406,7 +409,7 @@ class LLM:
tools
=
tools
,
tools
=
tools
,
)
)
else
:
else
:
prompt
=
apply_hf_chat_template
(
prompt
_data
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
...
@@ -414,17 +417,17 @@ class LLM:
...
@@ -414,17 +417,17 @@ class LLM:
tools
=
tools
,
tools
=
tools
,
)
)
inputs
:
Prompt
Inputs
prompt
:
Prompt
Type
if
is_list_of
(
prompt
,
int
):
if
is_list_of
(
prompt
_data
,
int
):
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt
_data
)
else
:
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
prompt
=
TextPrompt
(
prompt
=
prompt
_data
)
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
mm_data
prompt
[
"multi_modal_data"
]
=
mm_data
return
self
.
generate
(
return
self
.
generate
(
inputs
,
prompt
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -494,8 +497,8 @@ class LLM:
...
@@ -494,8 +497,8 @@ class LLM:
@
overload
@
overload
def
encode
(
def
encode
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
# We may enable `inputs` keyword after removing the old API
/
,
*
,
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
...
@@ -512,7 +515,7 @@ class LLM:
...
@@ -512,7 +515,7 @@ class LLM:
)
)
def
encode
(
def
encode
(
self
,
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
...
@@ -528,9 +531,9 @@ class LLM:
...
@@ -528,9 +531,9 @@ class LLM:
into a single list and pass it to this method.
into a single list and pass it to this method.
Args:
Args:
inpu
ts: The
inpu
ts to the LLM. You may pass a sequence of
inputs for
promp
ts: The
promp
ts to the LLM. You may pass a sequence of
prompts
batch inference. See :class:`~vllm.inputs.Prompt
Inputs
`
for
batch inference. See :class:`~vllm.inputs.Prompt
Type
`
for more details about the format of each
input
.
for more details about the format of each
prompts
.
pooling_params: The pooling parameters for pooling. If None, we
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
...
@@ -553,19 +556,20 @@ class LLM:
...
@@ -553,19 +556,20 @@ class LLM:
)
)
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
)
)
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
pooling_params
is
None
:
if
pooling_params
is
None
:
# Use default pooling params.
# Use default pooling params.
pooling_params
=
PoolingParams
()
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -609,9 +613,9 @@ class LLM:
...
@@ -609,9 +613,9 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
"provided."
)
inpu
ts
:
List
[
Prompt
Inputs
]
=
[]
parsed_promp
ts
:
List
[
Prompt
Type
]
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
item
:
Prompt
Inputs
item
:
Prompt
Type
if
prompts
is
not
None
:
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
...
@@ -620,24 +624,24 @@ class LLM:
...
@@ -620,24 +624,24 @@ class LLM:
else
:
else
:
raise
AssertionError
raise
AssertionError
inpu
ts
.
append
(
item
)
parsed_promp
ts
.
append
(
item
)
return
inpu
ts
return
parsed_promp
ts
def
_validate_and_add_requests
(
def
_validate_and_add_requests
(
self
,
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inpu
ts
,
(
str
,
dict
)):
if
isinstance
(
promp
ts
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
inpu
ts
=
[
inpu
ts
]
promp
ts
=
[
promp
ts
]
num_requests
=
len
(
inpu
ts
)
num_requests
=
len
(
promp
ts
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
"must be the same."
)
...
@@ -654,9 +658,9 @@ class LLM:
...
@@ -654,9 +658,9 @@ class LLM:
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inpu
ts
):
for
i
,
prompt
in
enumerate
(
promp
ts
):
self
.
_add_request
(
self
.
_add_request
(
request_inputs
,
prompt
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
lora_request
,
Sequence
)
else
lora_request
,
...
@@ -665,7 +669,7 @@ class LLM:
...
@@ -665,7 +669,7 @@ class LLM:
def
_add_request
(
def
_add_request
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
@@ -673,7 +677,7 @@ class LLM:
...
@@ -673,7 +677,7 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
self
.
llm_engine
.
add_request
(
request_id
,
request_id
,
inputs
,
prompt
,
params
,
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/inputs/__init__.py
View file @
0057894e
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
...
@@ -16,8 +16,8 @@ See also:
...
@@ -16,8 +16,8 @@ See also:
__all__
=
[
__all__
=
[
"TextPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"TokensPrompt"
,
"Prompt
Inputs
"
,
"Prompt
Type
"
,
"SingletonPrompt
Inputs
"
,
"SingletonPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
"EncoderDecoderLLMInputs"
,
...
...
vllm/inputs/data.py
View file @
0057894e
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
...
@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
"""
SingletonPrompt
Inputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
Set of possible schemas for a single LLM input:
Set of possible schemas for a single LLM input:
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
...
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPrompt
Inputs
` may be employed
A prompt of type :class:`SingletonPrompt
Type
` may be employed
as (1) input to a decoder-only model, (2) input to
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
where the decoder-prompt is not specified explicitly, or
...
@@ -55,12 +55,12 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
...
@@ -55,12 +55,12 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
,
covariant
=
True
)
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPrompt
Inputs
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
Inputs
,
default
=
SingletonPrompt
,
covariant
=
True
)
covariant
=
True
)
...
@@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
The encoder and decoder prompts, respectively,
The encoder and decoder prompts, respectively,
may formatted according to any of the
may formatted according to any of the
:class:`SingletonPrompt
Inputs
` schemas, and are not
:class:`SingletonPrompt
Type
` schemas, and are not
required to have the same schema.
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Only the encoder prompt may have multi-modal data.
...
@@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
be used as an input to a decoder-only model,
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
fields of this data structure themselves must be
:class:`SingletonPrompt
Inputs
` instances.
:class:`SingletonPrompt
Type
` instances.
"""
"""
encoder_prompt
:
_T1_co
encoder_prompt
:
_T1_co
...
@@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
decoder_prompt
:
Optional
[
_T2_co
]
Prompt
Inputs
=
Union
[
SingletonPrompt
Inputs
,
ExplicitEncoderDecoderPrompt
]
Prompt
Type
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
"""
"""
Set of possible schemas for an LLM input, including
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
both decoder-only and encoder/decoder input types:
...
@@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
...
@@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
"""
_T1
=
TypeVar
(
"_T1"
,
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
bound
=
SingletonPromptInputs
,
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
default
=
SingletonPromptInputs
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
def
build_explicit_enc_dec_prompt
(
def
build_explicit_enc_dec_prompt
(
...
...
vllm/inputs/parse.py
View file @
0057894e
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
...
@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
,
TextPrompt
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
)
TokensPrompt
)
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
...
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def
parse_singleton_prompt
(
def
parse_singleton_prompt
(
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
]:
if
isinstance
(
inputs
,
str
):
if
isinstance
(
prompt
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
inputs
)
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
elif
isinstance
(
inputs
,
dict
):
elif
isinstance
(
prompt
,
dict
):
if
"prompt_token_ids"
in
inputs
:
if
"prompt_token_ids"
in
prompt
:
return
ParsedTokensPrompt
(
type
=
"tokens"
,
return
ParsedTokensPrompt
(
type
=
"tokens"
,
content
=
inputs
)
# type: ignore
content
=
prompt
)
# type: ignore
elif
"prompt"
in
inputs
:
elif
"prompt"
in
prompt
:
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
inputs
)
return
ParsedTextPrompt
(
type
=
"text"
,
content
=
prompt
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
raise
TypeError
(
"inputs must be a string, TextPrompt, or TokensPrompt"
)
def
is_explicit_encoder_decoder_prompt
(
def
is_explicit_encoder_decoder_prompt
(
inputs
:
Prompt
Inputs
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
prompt
:
Prompt
Type
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
inputs
,
dict
)
and
"encoder_prompt"
in
inputs
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
def
is_valid_encoder_decoder_llm_inputs
(
def
is_valid_encoder_decoder_llm_inputs
(
...
...
vllm/inputs/preprocess.py
View file @
0057894e
...
@@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
...
@@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
.data
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
Prompt
Inputs
,
from
.data
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
Prompt
Type
,
SingletonPrompt
Inputs
)
SingletonPrompt
)
from
.parse
import
is_explicit_encoder_decoder_prompt
,
parse_singleton_prompt
from
.parse
import
is_explicit_encoder_decoder_prompt
,
parse_singleton_prompt
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -206,7 +206,7 @@ class InputPreprocessor:
...
@@ -206,7 +206,7 @@ class InputPreprocessor:
def
_extract_prompt_components
(
def
_extract_prompt_components
(
self
,
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
)
->
PromptComponents
:
...
@@ -216,7 +216,7 @@ class InputPreprocessor:
...
@@ -216,7 +216,7 @@ class InputPreprocessor:
Arguments:
Arguments:
* request_id
* request_id
*
inputs
: single encoder or decoder input prompt
*
prompt
: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* lora_request: this is only valid for decoder prompts
Returns:
Returns:
...
@@ -226,24 +226,24 @@ class InputPreprocessor:
...
@@ -226,24 +226,24 @@ class InputPreprocessor:
* multi_modal_data
* multi_modal_data
'''
'''
parsed
=
parse_singleton_prompt
(
inputs
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"str"
:
if
parsed
[
"type"
]
==
"str"
:
prompt
=
parsed
[
"content"
]
prompt
_text
=
parsed
[
"content"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
None
multi_modal_data
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
elif
parsed
[
"type"
]
==
"tokens"
:
prompt
=
None
prompt
_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
elif
parsed
[
"type"
]
==
"text"
:
elif
parsed
[
"type"
]
==
"text"
:
prompt
=
parsed
[
"content"
][
"prompt"
]
prompt
_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -251,33 +251,33 @@ class InputPreprocessor:
...
@@ -251,33 +251,33 @@ class InputPreprocessor:
else
:
else
:
assert_never
(
parsed
)
assert_never
(
parsed
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
return
prompt
_text
,
prompt_token_ids
,
multi_modal_data
async
def
_extract_prompt_components_async
(
async
def
_extract_prompt_components_async
(
self
,
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
)
->
PromptComponents
:
"""Async version of :meth:`_extract_prompt_components`."""
"""Async version of :meth:`_extract_prompt_components`."""
parsed
=
parse_singleton_prompt
(
inputs
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"str"
:
if
parsed
[
"type"
]
==
"str"
:
prompt
=
parsed
[
"content"
]
prompt
_text
=
parsed
[
"content"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
None
multi_modal_data
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
elif
parsed
[
"type"
]
==
"tokens"
:
prompt
=
None
prompt
_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
elif
parsed
[
"type"
]
==
"text"
:
elif
parsed
[
"type"
]
==
"text"
:
prompt
=
parsed
[
"content"
][
"prompt"
]
prompt
_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -285,7 +285,7 @@ class InputPreprocessor:
...
@@ -285,7 +285,7 @@ class InputPreprocessor:
else
:
else
:
assert_never
(
parsed
)
assert_never
(
parsed
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
return
prompt
_text
,
prompt_token_ids
,
multi_modal_data
def
_build_enc_dec_llm_inputs
(
def
_build_enc_dec_llm_inputs
(
self
,
self
,
...
@@ -311,7 +311,7 @@ class InputPreprocessor:
...
@@ -311,7 +311,7 @@ class InputPreprocessor:
def
_process_encoder_decoder_prompt
(
def
_process_encoder_decoder_prompt
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
)
->
EncoderDecoderLLMInputs
:
'''
'''
...
@@ -339,7 +339,7 @@ class InputPreprocessor:
...
@@ -339,7 +339,7 @@ class InputPreprocessor:
Arguments:
Arguments:
*
inputs
: an input prompt
*
prompt
: an input prompt
* request_id
* request_id
Returns:
Returns:
...
@@ -350,13 +350,13 @@ class InputPreprocessor:
...
@@ -350,13 +350,13 @@ class InputPreprocessor:
encoder_comps
:
PromptComponents
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_comps
=
self
.
_extract_prompt_components
(
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
[
"encoder_prompt"
],
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
decoder_comps
=
None
,
None
,
None
else
:
else
:
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_comps
=
self
.
_extract_prompt_components
(
...
@@ -365,7 +365,7 @@ class InputPreprocessor:
...
@@ -365,7 +365,7 @@ class InputPreprocessor:
)
)
else
:
else
:
encoder_comps
=
self
.
_extract_prompt_components
(
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
...
@@ -375,20 +375,20 @@ class InputPreprocessor:
...
@@ -375,20 +375,20 @@ class InputPreprocessor:
async
def
_process_encoder_decoder_prompt_async
(
async
def
_process_encoder_decoder_prompt_async
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
)
->
EncoderDecoderLLMInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps
:
PromptComponents
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_task
=
self
.
_extract_prompt_components_async
(
encoder_task
=
self
.
_extract_prompt_components_async
(
inputs
[
"encoder_prompt"
],
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
decoder_comps
=
None
,
None
,
None
else
:
else
:
...
@@ -401,7 +401,7 @@ class InputPreprocessor:
...
@@ -401,7 +401,7 @@ class InputPreprocessor:
encoder_task
,
decoder_task
)
encoder_task
,
decoder_task
)
else
:
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
...
@@ -425,7 +425,7 @@ class InputPreprocessor:
...
@@ -425,7 +425,7 @@ class InputPreprocessor:
def
_process_decoder_only_prompt
(
def
_process_decoder_only_prompt
(
self
,
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
@@ -436,7 +436,7 @@ class InputPreprocessor:
...
@@ -436,7 +436,7 @@ class InputPreprocessor:
Arguments:
Arguments:
*
inputs
: input prompt
*
prompt
: input prompt
* request_id
* request_id
* lora_request
* lora_request
* prompt_adapter_request
* prompt_adapter_request
...
@@ -447,7 +447,7 @@ class InputPreprocessor:
...
@@ -447,7 +447,7 @@ class InputPreprocessor:
'''
'''
prompt_comps
=
self
.
_extract_prompt_components
(
prompt_comps
=
self
.
_extract_prompt_components
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -459,14 +459,14 @@ class InputPreprocessor:
...
@@ -459,14 +459,14 @@ class InputPreprocessor:
async
def
_process_decoder_only_prompt_async
(
async
def
_process_decoder_only_prompt_async
(
self
,
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
)
->
LLMInputs
:
"""Async version of :meth:`_process_decoder_only_prompt`."""
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -478,7 +478,7 @@ class InputPreprocessor:
...
@@ -478,7 +478,7 @@ class InputPreprocessor:
def
preprocess
(
def
preprocess
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
@@ -488,17 +488,17 @@ class InputPreprocessor:
...
@@ -488,17 +488,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder
return
self
.
_process_encoder_decoder_prompt
(
return
self
.
_process_encoder_decoder_prompt
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
"to decoder-only models"
)
# Decoder-only operation
# Decoder-only operation
return
self
.
_process_decoder_only_prompt
(
return
self
.
_process_decoder_only_prompt
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
@@ -506,7 +506,7 @@ class InputPreprocessor:
...
@@ -506,7 +506,7 @@ class InputPreprocessor:
async
def
preprocess_async
(
async
def
preprocess_async
(
self
,
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
@@ -516,17 +516,17 @@ class InputPreprocessor:
...
@@ -516,17 +516,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder
return
await
self
.
_process_encoder_decoder_prompt_async
(
return
await
self
.
_process_encoder_decoder_prompt_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
"to decoder-only models"
)
# Decoder-only operation
# Decoder-only operation
return
await
self
.
_process_decoder_only_prompt_async
(
return
await
self
.
_process_decoder_only_prompt_async
(
inputs
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment