Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
574fe752
Unverified
Commit
574fe752
authored
Feb 17, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 17, 2026
Browse files
[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
c61a98f5
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
559 additions
and
493 deletions
+559
-493
tests/entrypoints/llm/test_chat.py
tests/entrypoints/llm/test_chat.py
+6
-9
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+3
-2
tests/renderers/test_process_multi_modal_uuids.py
tests/renderers/test_process_multi_modal_uuids.py
+165
-0
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+0
-2
vllm/beam_search.py
vllm/beam_search.py
+29
-8
vllm/engine/protocol.py
vllm/engine/protocol.py
+6
-6
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+203
-221
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+12
-25
vllm/entrypoints/openai/completion/serving.py
vllm/entrypoints/openai/completion/serving.py
+4
-22
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+41
-87
vllm/entrypoints/openai/realtime/serving.py
vllm/entrypoints/openai/realtime/serving.py
+8
-2
vllm/entrypoints/openai/responses/context.py
vllm/entrypoints/openai/responses/context.py
+2
-2
vllm/entrypoints/openai/responses/serving.py
vllm/entrypoints/openai/responses/serving.py
+3
-6
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
+49
-56
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+3
-12
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+2
-7
vllm/entrypoints/pooling/score/serving.py
vllm/entrypoints/pooling/score/serving.py
+13
-7
vllm/entrypoints/serve/disagg/serving.py
vllm/entrypoints/serve/disagg/serving.py
+2
-17
vllm/entrypoints/serve/tokenize/serving.py
vllm/entrypoints/serve/tokenize/serving.py
+2
-2
vllm/inputs/data.py
vllm/inputs/data.py
+6
-0
No files found.
tests/entrypoints/llm/test_chat.py
View file @
574fe752
...
@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
...
@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
valid_msg
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
long_text
=
"This is a very long text to test the error "
*
50
long_text
=
"This is a very long text to test the error "
*
50
invalid_msg
=
[{
"role"
:
"user"
,
"content"
:
long_text
}]
invalid_msg
=
[{
"role"
:
"user"
,
"content"
:
long_text
}]
batch_1
=
[
valid_msg
,
batch_1
=
[
valid_msg
,
valid_msg
,
invalid_msg
]
valid_msg
,
batch_2
=
[
valid_msg
,
valid_msg
]
invalid_msg
,
]
batch_2
=
[
valid_msg
,
valid_msg
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
)
with
pytest
.
raises
(
ValueError
,
match
=
"context length is only"
):
with
pytest
.
raises
(
ValueError
,
match
=
"context length is only"
):
llm
.
chat
(
batch_1
,
sampling_params
=
sampling_params
)
llm
.
chat
(
batch_1
,
sampling_params
=
sampling_params
)
assert
llm
.
llm_engine
.
get_num_unfinished_requests
()
==
0
outputs_2
=
llm
.
chat
(
batch_2
,
sampling_params
=
sampling_params
)
outputs_2
=
llm
.
chat
(
batch_2
,
sampling_params
=
sampling_params
)
assert
len
(
outputs_2
)
==
len
(
batch_2
)
assert
len
(
outputs_2
)
==
len
(
batch_2
)
assert
llm
.
llm_engine
.
get_num_unfinished_requests
()
==
0
assert
llm
.
llm_engine
.
get_num_unfinished_requests
()
==
0
tests/models/multimodal/processing/test_common.py
View file @
574fe752
...
@@ -489,8 +489,9 @@ def _assert_inputs_equal(
...
@@ -489,8 +489,9 @@ def _assert_inputs_equal(
if
ignore_mm_keys
is
None
:
if
ignore_mm_keys
is
None
:
ignore_mm_keys
=
set
()
ignore_mm_keys
=
set
()
a_rest
=
{
k
:
v
for
k
,
v
in
a
.
items
()
if
k
!=
"mm_kwargs"
}
ignore_prompt_keys
=
(
"prompt"
,
"mm_kwargs"
)
b_rest
=
{
k
:
v
for
k
,
v
in
b
.
items
()
if
k
!=
"mm_kwargs"
}
a_rest
=
{
k
:
v
for
k
,
v
in
a
.
items
()
if
k
not
in
ignore_prompt_keys
}
b_rest
=
{
k
:
v
for
k
,
v
in
b
.
items
()
if
k
not
in
ignore_prompt_keys
}
assert
a_rest
==
b_rest
,
msg
assert
a_rest
==
b_rest
,
msg
...
...
tests/
v1/engine
/test_process_multi_modal_uuids.py
→
tests/
renderers
/test_process_multi_modal_uuids.py
View file @
574fe752
...
@@ -6,18 +6,17 @@ import pytest
...
@@ -6,18 +6,17 @@ import pytest
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.multimodal
import
MultiModalUUIDDict
from
vllm.renderers.hf
import
HfRenderer
from
vllm.sampling_params
import
SamplingParams
from
vllm.tokenizers.registry
import
tokenizer_args_from_config
from
vllm.v1.engine.input_processor
import
InputProcessor
cherry_pil_image
=
ImageAsset
(
"cherry_blossom"
).
pil_image
cherry_pil_image
=
ImageAsset
(
"cherry_blossom"
).
pil_image
stop_pil_image
=
ImageAsset
(
"stop_sign"
).
pil_image
stop_pil_image
=
ImageAsset
(
"stop_sign"
).
pil_image
baby_reading_np_ndarrays
=
VideoAsset
(
"baby_reading"
).
np_ndarrays
baby_reading_np_ndarrays
=
VideoAsset
(
"baby_reading"
).
np_ndarrays
def
_build_
input_processo
r
(
def
_build_
rendere
r
(
*
,
mm_cache_gb
:
float
=
4.0
,
enable_prefix_caching
:
bool
=
True
*
,
mm_cache_gb
:
float
=
4.0
,
enable_prefix_caching
:
bool
=
True
)
->
InputProcesso
r
:
)
->
HfRendere
r
:
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
model
=
"Qwen/Qwen2.5-VL-3B-Instruct"
,
model
=
"Qwen/Qwen2.5-VL-3B-Instruct"
,
max_model_len
=
128
,
max_model_len
=
128
,
...
@@ -29,47 +28,45 @@ def _build_input_processor(
...
@@ -29,47 +28,45 @@ def _build_input_processor(
cache_config
=
CacheConfig
(
enable_prefix_caching
=
enable_prefix_caching
),
cache_config
=
CacheConfig
(
enable_prefix_caching
=
enable_prefix_caching
),
)
)
return
InputProcessor
(
vllm_config
)
_
,
tokenizer_name
,
_
,
kwargs
=
tokenizer_args_from_config
(
model_config
)
return
HfRenderer
.
from_config
(
vllm_config
,
tokenizer_kwargs
=
{
**
kwargs
,
"tokenizer_name"
:
tokenizer_name
},
)
def
test_multi_modal_uuids_length_mismatch_raises
():
def
test_multi_modal_uuids_length_mismatch_raises
():
input_processor
=
_build_input_processo
r
()
renderer
=
_build_rendere
r
()
prompt
=
{
mm_data
=
{
"image"
:
[
cherry_pil_image
,
stop_pil_image
]}
"prompt"
:
"USER: <image>
\n
Describe
\n
ASSISTANT:"
,
"multi_modal_data"
:
{
"image"
:
[
cherry_pil_image
,
stop_pil_image
]},
# Mismatch: 2 items but only 1 uuid provided
# Mismatch: 2 items but only 1 uuid provided
mm_uuids
=
{
"image"
:
[
"hash_cherry"
]}
"multi_modal_uuids"
:
{
"image"
:
[
"hash_cherry"
]},
}
mm_processor
=
renderer
.
get_mm_processor
()
mm_items
=
mm_processor
.
info
.
parse_mm_data
(
mm_data
)
with
pytest
.
raises
(
ValueError
,
match
=
"must have same length as"
):
with
pytest
.
raises
(
ValueError
,
match
=
"must have same length as"
):
input_processor
.
process_inputs
(
renderer
.
_process_mm_uuids
(
mm_data
,
mm_items
,
mm_uuids
,
"req-1"
)
request_id
=
"req-1"
,
prompt
=
prompt
,
# type: ignore[arg-type]
params
=
SamplingParams
(),
)
def
test_multi_modal_uuids_missing_modality_raises
():
def
test_multi_modal_uuids_missing_modality_raises
():
input_processor
=
_build_input_processor
()
renderer
=
_build_renderer
()
prompt
=
{
mm_data
=
{
"prompt"
:
"USER: <image><video>
\n
Describe
\n
ASSISTANT:"
,
"image"
:
[
cherry_pil_image
],
# Two modalities provided in data
"video"
:
None
,
"multi_modal_data"
:
{
"image"
:
[
cherry_pil_image
],
"video"
:
None
,
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids"
:
{
"image"
:
[
"hash_cherry"
]},
}
}
# Only image uuids provided; video missing should raise
mm_uuids
=
{
"image"
:
[
"hash_cherry"
]}
mm_processor
=
renderer
.
get_mm_processor
()
mm_items
=
mm_processor
.
info
.
parse_mm_data
(
mm_data
)
with
pytest
.
raises
(
ValueError
,
match
=
"is empty but .* is missing"
):
with
pytest
.
raises
(
ValueError
,
match
=
"is empty but .* is missing"
):
input_processor
.
process_inputs
(
renderer
.
_process_mm_uuids
(
mm_data
,
mm_items
,
mm_uuids
,
"req-2"
)
request_id
=
"req-2"
,
prompt
=
prompt
,
# type: ignore[arg-type]
params
=
SamplingParams
(),
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -83,92 +80,86 @@ def test_multi_modal_uuids_missing_modality_raises():
...
@@ -83,92 +80,86 @@ def test_multi_modal_uuids_missing_modality_raises():
def
test_multi_modal_uuids_accepts_none_and_passes_through
(
def
test_multi_modal_uuids_accepts_none_and_passes_through
(
monkeypatch
,
mm_cache_gb
:
float
,
enable_prefix_caching
:
bool
monkeypatch
,
mm_cache_gb
:
float
,
enable_prefix_caching
:
bool
):
):
input_processor
=
_build_input_processo
r
(
renderer
=
_build_rendere
r
(
mm_cache_gb
=
mm_cache_gb
,
mm_cache_gb
=
mm_cache_gb
,
enable_prefix_caching
=
enable_prefix_caching
,
enable_prefix_caching
=
enable_prefix_caching
,
)
)
# Capture the overrides passed to InputPreprocessor.preprocess
mm_data
=
{
captured
:
dict
[
str
,
object
]
=
{}
"image"
:
[
cherry_pil_image
,
stop_pil_image
],
"video"
:
baby_reading_np_ndarrays
,
def
fake_preprocess
(
}
prompt
,
*
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
mm_uuids
=
None
):
captured
[
"mm_uuids"
]
=
mm_uuids
# Minimal processed inputs for decoder-only flow
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
# Monkeypatch only the bound preprocess method on this instance
monkeypatch
.
setattr
(
input_processor
.
input_preprocessor
,
"preprocess"
,
fake_preprocess
,
raising
=
True
)
# Use a consistent two-image scenario across all configurations
# Use a consistent two-image scenario across all configurations
mm_uuids
=
{
"image"
:
[
None
,
"hash_stop"
],
"video"
:
None
}
mm_uuids
=
{
"image"
:
[
None
,
"hash_stop"
],
"video"
:
None
}
prompt
=
{
"prompt"
:
"USER: <image><image>
\n
Two images
\n
ASSISTANT:"
,
"multi_modal_data"
:
{
"image"
:
[
cherry_pil_image
,
stop_pil_image
],
"video"
:
baby_reading_np_ndarrays
,
},
"multi_modal_uuids"
:
mm_uuids
,
}
input
_processor
.
process_inputs
(
mm
_processor
=
renderer
.
get_mm_processor
()
request_id
=
"req-3"
,
mm_items
=
mm_processor
.
info
.
parse_mm_data
(
mm_data
)
pro
mpt
=
prompt
,
# type: ignore[arg-type]
pro
cessed_mm_uuids
=
renderer
.
_process_mm_uuids
(
params
=
SamplingParams
(),
mm_data
,
mm_items
,
mm_uuids
,
"req-3"
)
)
assert
captured
[
"
mm_uuids
"
]
==
mm_uuids
assert
processed_
mm_uuids
==
mm_uuids
def
test_multi_modal_uuids_ignored_when_caching_disabled
(
monkeypatch
):
@
pytest
.
mark
.
parametrize
(
# When both processor cache is 0 and prefix caching disabled, the
"mm_cache_gb, enable_prefix_caching"
,
# processor builds overrides from request id instead of using user UUIDs.
[
input_processor
=
_build_input_processor
(
(
4.0
,
True
),
# default behavior
mm_cache_gb
=
0.0
,
enable_prefix_caching
=
False
(
4.0
,
False
),
# prefix caching disabled
(
0.0
,
True
),
# processor cache disabled
],
)
def
test_multi_modal_uuids_accepts_empty
(
monkeypatch
,
mm_cache_gb
:
float
,
enable_prefix_caching
:
bool
):
renderer
=
_build_renderer
(
mm_cache_gb
=
mm_cache_gb
,
enable_prefix_caching
=
enable_prefix_caching
,
)
)
captured
:
dict
[
str
,
MultiModalUUIDDict
]
=
{}
# While None means cached multi-modal input requiring UUIDs
# an empty list means no multi-modal input
mm_data
=
{
"image"
:
[],
"video"
:
[]}
# type: ignore[var-annotated]
mm_uuids
=
{
"image"
:
[],
"video"
:
None
}
# type: ignore[var-annotated]
def
fake_preprocess
(
mm_processor
=
renderer
.
get_mm_processor
()
prompt
,
*
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
mm_uuids
=
None
mm_items
=
mm_processor
.
info
.
parse_mm_data
(
mm_data
)
):
processed_mm_uuids
=
renderer
.
_process_mm_uuids
(
captured
[
"mm_uuids"
]
=
mm_uuids
mm_data
,
mm_items
,
mm_uuids
,
"req-4"
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
monkeypatch
.
setattr
(
input_processor
.
input_preprocessor
,
"preprocess"
,
fake_preprocess
,
raising
=
True
)
)
assert
processed_mm_uuids
==
mm_uuids
def
test_multi_modal_uuids_ignored_when_caching_disabled
(
monkeypatch
):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
renderer
=
_build_renderer
(
mm_cache_gb
=
0.0
,
enable_prefix_caching
=
False
)
request_id
=
"req-42"
request_id
=
"req-42"
mm_uuids
=
{
"image"
:
[
"hash_cherry"
,
"hash_stop"
],
"video"
:
[
"hash_video"
]}
mm_data
=
{
prompt
=
{
"image"
:
[
cherry_pil_image
,
stop_pil_image
],
"prompt"
:
"USER: <image><image><video>
\n
Describe
\n
ASSISTANT:"
,
"video"
:
baby_reading_np_ndarrays
,
"multi_modal_data"
:
{
"image"
:
[
cherry_pil_image
,
stop_pil_image
],
"video"
:
[
baby_reading_np_ndarrays
],
},
"multi_modal_uuids"
:
mm_uuids
,
}
}
mm_uuids
=
{
"image"
:
[
"hash_cherry"
,
"hash_stop"
],
"video"
:
[
"hash_video"
]}
input
_processor
.
process_inputs
(
mm
_processor
=
renderer
.
get_mm_processor
()
request_id
=
request_id
,
mm_items
=
mm_processor
.
info
.
parse_mm_data
(
mm_data
)
pro
mpt
=
prompt
,
# type: ignore[arg-type]
pro
cessed_mm_uuids
=
renderer
.
_process_mm_uuids
(
params
=
SamplingParams
(),
mm_data
,
mm_items
,
mm_uuids
,
request_id
)
)
# Expect request-id-based overrides are passed through
# Expect request-id-based overrides are passed through
assert
set
(
mm_uuids
.
keys
())
==
{
"image"
,
"video"
}
assert
set
(
mm_uuids
.
keys
())
==
{
"image"
,
"video"
}
assert
len
(
mm_uuids
[
"image"
])
==
2
assert
len
(
mm_uuids
[
"image"
])
==
2
assert
len
(
mm_uuids
[
"video"
])
==
1
assert
len
(
mm_uuids
[
"video"
])
==
1
assert
captured
[
"
mm_uuids
"
]
[
"image"
][
0
].
startswith
(
assert
processed_
mm_uuids
[
"image"
][
0
].
startswith
(
f
"
{
request_id
}
-image-"
f
"
{
request_id
}
-image-"
)
and
captured
[
"
mm_uuids
"
]
[
"image"
][
0
].
endswith
(
"-0"
)
)
and
processed_
mm_uuids
[
"image"
][
0
].
endswith
(
"-0"
)
assert
captured
[
"
mm_uuids
"
]
[
"image"
][
1
].
startswith
(
assert
processed_
mm_uuids
[
"image"
][
1
].
startswith
(
f
"
{
request_id
}
-image-"
f
"
{
request_id
}
-image-"
)
and
captured
[
"
mm_uuids
"
]
[
"image"
][
1
].
endswith
(
"-1"
)
)
and
processed_
mm_uuids
[
"image"
][
1
].
endswith
(
"-1"
)
assert
captured
[
"
mm_uuids
"
]
[
"video"
][
0
].
startswith
(
assert
processed_
mm_uuids
[
"video"
][
0
].
startswith
(
f
"
{
request_id
}
-video-"
f
"
{
request_id
}
-video-"
)
and
captured
[
"
mm_uuids
"
]
[
"video"
][
0
].
endswith
(
"-0"
)
)
and
processed_
mm_uuids
[
"video"
][
0
].
endswith
(
"-0"
)
tests/samplers/test_beam_search.py
View file @
574fe752
...
@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
...
@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
MODELS
=
[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]
MODELS
=
[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]
@
pytest
.
mark
.
skip_v1
# V1 engine does not yet support beam search
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
...
@@ -62,7 +61,6 @@ def test_beam_search_single_input(
...
@@ -62,7 +61,6 @@ def test_beam_search_single_input(
)
)
@
pytest
.
mark
.
skip_v1
# V1 engine does not yet support beam search
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
...
...
vllm/beam_search.py
View file @
574fe752
...
@@ -2,13 +2,11 @@
...
@@ -2,13 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
from
vllm.inputs
import
TokenInputs
,
token_inputs
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalInputs
,
mm_inputs
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
@
dataclass
@
dataclass
...
@@ -19,6 +17,8 @@ class BeamSearchSequence:
...
@@ -19,6 +17,8 @@ class BeamSearchSequence:
about to be returned to the user.
about to be returned to the user.
"""
"""
orig_prompt
:
TokenInputs
|
MultiModalInputs
# The tokens include the prompt.
# The tokens include the prompt.
tokens
:
list
[
int
]
tokens
:
list
[
int
]
logprobs
:
list
[
dict
[
int
,
Logprob
]]
logprobs
:
list
[
dict
[
int
,
Logprob
]]
...
@@ -27,8 +27,28 @@ class BeamSearchSequence:
...
@@ -27,8 +27,28 @@ class BeamSearchSequence:
text
:
str
|
None
=
None
text
:
str
|
None
=
None
finish_reason
:
str
|
None
=
None
finish_reason
:
str
|
None
=
None
stop_reason
:
int
|
str
|
None
=
None
stop_reason
:
int
|
str
|
None
=
None
multi_modal_data
:
"MultiModalDataDict | None"
=
None
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
def
get_prompt
(
self
):
prompt
=
self
.
orig_prompt
prompt_text
=
prompt
.
get
(
"prompt"
)
cache_salt
=
prompt
.
get
(
"cache_salt"
)
if
prompt
[
"type"
]
==
"token"
:
return
token_inputs
(
self
.
tokens
,
prompt
=
prompt_text
,
cache_salt
=
cache_salt
,
)
return
mm_inputs
(
prompt_token_ids
=
self
.
tokens
,
mm_kwargs
=
prompt
[
"mm_kwargs"
],
mm_hashes
=
prompt
[
"mm_hashes"
],
mm_placeholders
=
prompt
[
"mm_placeholders"
],
prompt
=
prompt_text
,
cache_salt
=
cache_salt
,
)
@
dataclass
@
dataclass
...
@@ -44,14 +64,15 @@ class BeamSearchOutput:
...
@@ -44,14 +64,15 @@ class BeamSearchOutput:
class
BeamSearchInstance
:
class
BeamSearchInstance
:
def
__init__
(
def
__init__
(
self
,
self
,
prompt
_t
oken
s
:
list
[
int
]
,
prompt
:
T
oken
Inputs
|
MultiModalInputs
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
logprobs
:
list
[
dict
[
int
,
Logprob
]]
|
None
=
None
,
logprobs
:
list
[
dict
[
int
,
Logprob
]]
|
None
=
None
,
**
kwargs
,
**
kwargs
,
):
):
self
.
beams
:
list
[
BeamSearchSequence
]
=
[
self
.
beams
:
list
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
BeamSearchSequence
(
tokens
=
prompt_tokens
,
orig_prompt
=
prompt
,
tokens
=
prompt
[
"prompt_token_ids"
],
logprobs
=
[]
if
logprobs
is
None
else
list
(
logprobs
),
logprobs
=
[]
if
logprobs
is
None
else
list
(
logprobs
),
lora_request
=
lora_request
,
lora_request
=
lora_request
,
**
kwargs
,
**
kwargs
,
...
...
vllm/engine/protocol.py
View file @
574fe752
...
@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
...
@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest
,
WeightTransferInitRequest
,
WeightTransferUpdateRequest
,
WeightTransferUpdateRequest
,
)
)
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
ProcessorInputs
,
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.plugins.io_processors
import
IOProcessor
from
vllm.plugins.io_processors
import
IOProcessor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers.inputs
import
DictPrompt
,
TokPrompt
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
...
@@ -35,7 +34,7 @@ class StreamingInput:
...
@@ -35,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator.
where inputs are provided via an async generator.
"""
"""
prompt
:
Pro
mptType
prompt
:
Pro
cessorInputs
sampling_params
:
SamplingParams
|
None
=
None
sampling_params
:
SamplingParams
|
None
=
None
...
@@ -69,8 +68,7 @@ class EngineClient(ABC):
...
@@ -69,8 +68,7 @@ class EngineClient(ABC):
self
,
self
,
prompt
:
EngineCoreRequest
prompt
:
EngineCoreRequest
|
PromptType
|
PromptType
|
DictPrompt
|
ProcessorInputs
|
TokPrompt
|
AsyncGenerator
[
StreamingInput
,
None
],
|
AsyncGenerator
[
StreamingInput
,
None
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
...
@@ -81,6 +79,7 @@ class EngineClient(ABC):
...
@@ -81,6 +79,7 @@ class EngineClient(ABC):
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
data_parallel_rank
:
int
|
None
=
None
,
data_parallel_rank
:
int
|
None
=
None
,
reasoning_ended
:
bool
|
None
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request."""
"""Generate outputs for a request."""
...
...
...
@@ -88,13 +87,14 @@ class EngineClient(ABC):
...
@@ -88,13 +87,14 @@ class EngineClient(ABC):
@
abstractmethod
@
abstractmethod
def
encode
(
def
encode
(
self
,
self
,
prompt
:
PromptType
|
DictPrompt
|
TokPrompt
,
prompt
:
PromptType
|
ProcessorInputs
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
reasoning_ended
:
bool
|
None
=
None
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
"""Generate outputs for a request from a pooling model."""
"""Generate outputs for a request from a pooling model."""
...
...
...
...
vllm/entrypoints/llm.py
View file @
574fe752
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
import
itertools
import
itertools
import
warnings
import
warnings
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
typing
import
TYPE_CHECKING
,
Any
,
cast
from
typing
import
TYPE_CHECKING
,
Any
import
cloudpickle
import
cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
...
@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
from
vllm.entrypoints.utils
import
log_non_default_args
from
vllm.entrypoints.utils
import
log_non_default_args
from
vllm.inputs.data
import
(
from
vllm.inputs.data
import
(
DataPrompt
,
DataPrompt
,
ProcessorInputs
,
PromptType
,
PromptType
,
SingletonPrompt
,
SingletonPrompt
,
TextPrompt
,
TextPrompt
,
...
@@ -73,10 +74,8 @@ from vllm.outputs import (
...
@@ -73,10 +74,8 @@ from vllm.outputs import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
ChatParams
,
merge_kwargs
from
vllm.renderers
import
ChatParams
,
merge_kwargs
from
vllm.renderers.inputs
import
DictPrompt
,
TokPrompt
from
vllm.renderers.inputs.preprocess
import
(
from
vllm.renderers.inputs.preprocess
import
(
conversation_to_seq
,
conversation_to_seq
,
extract_prompt_components
,
parse_model_prompt
,
parse_model_prompt
,
prompt_to_seq
,
prompt_to_seq
,
)
)
...
@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
...
@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.counter
import
Counter
from
vllm.utils.counter
import
Counter
from
vllm.utils.tqdm_utils
import
maybe_tqdm
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
...
@@ -400,7 +400,7 @@ class LLM:
...
@@ -400,7 +400,7 @@ class LLM:
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
list
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
list
[
RequestOutput
]:
)
->
list
[
RequestOutput
]:
...
@@ -462,7 +462,7 @@ class LLM:
...
@@ -462,7 +462,7 @@ class LLM:
self
,
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
prompts
:
PromptType
|
Sequence
[
PromptType
],
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
lora_request
:
list
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
...
@@ -495,34 +495,32 @@ class LLM:
...
@@ -495,34 +495,32 @@ class LLM:
# Use the same preprocessing as _run_completion
# Use the same preprocessing as _run_completion
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
seq_prompts
))
seq_params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
seq_prompts
))
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
seq_prompts
))
if
any
(
param
.
truncate_prompt_tokens
is
not
None
for
param
in
seq_params
):
seq_tok_kwargs
=
[
engine_prompts
:
Sequence
[
DictPrompt
|
TokPrompt
]
=
[
merge_kwargs
(
engine_prompt
tokenization_kwargs
,
for
prompt
,
param
in
zip
(
seq_prompts
,
seq_params
)
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
for
engine_prompt
in
self
.
_preprocess_cmpl
(
)
[
prompt
],
for
param
in
seq_params
tokenization_kwargs
=
merge_kwargs
(
]
tokenization_kwargs
,
seq_priority
=
self
.
_priority_to_seq
(
priority
,
len
(
prompts
))
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
request_ids
=
self
.
_render_and_add_requests
(
prompts
=
(
self
.
_preprocess_cmpl_one
(
prompt
,
tok_kwargs
)
for
prompt
,
tok_kwargs
in
zip
(
maybe_tqdm
(
seq_prompts
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering prompts"
,
),
),
seq_tok_kwargs
,
)
)
]
else
:
engine_prompts
=
self
.
_preprocess_cmpl
(
seq_prompts
,
tokenization_kwargs
=
tokenization_kwargs
,
)
request_ids
=
self
.
_validate_and_add_requests
(
prompts
=
engine_prompts
,
params
=
seq_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
self
.
_get_modality_specific_lora_reqs
(
engine_prompts
,
lora_request
),
),
params
=
seq_params
,
lora_requests
=
seq_lora_requests
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priorit
y
=
priority
,
priorit
ies
=
seq_
priority
,
)
)
return
request_ids
return
request_ids
...
@@ -545,53 +543,41 @@ class LLM:
...
@@ -545,53 +543,41 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
_
get_modality_specific
_lora_reqs
(
def
_
resolve
_lora_reqs
(
self
,
self
,
prompts
:
Sequence
[
DictPrompt
|
TokPrompt
],
prompts
:
Sequence
[
ProcessorInputs
],
lora_request
:
list
[
LoRARequest
]
|
LoRARequest
|
None
,
lora_request
:
Sequence
[
LoRARequest
|
None
]
|
LoRARequest
|
None
,
):
):
# Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1.
lora_config
=
self
.
llm_engine
.
vllm_config
.
lora_config
lora_config
=
self
.
llm_engine
.
vllm_config
.
lora_config
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
prompts
))
# If there's no lora config / default_mm_loras, or the model
# isn't multimodal, leave the lora as is.
if
(
if
(
lora_config
is
None
lora_config
is
None
or
not
self
.
model_config
.
is_multimodal_model
or
not
self
.
model_config
.
is_multimodal_model
or
(
lora_config
and
lora_config
.
default_mm_loras
is
None
)
or
(
lora_config
and
lora_config
.
default_mm_loras
is
None
)
):
):
return
lora_request
return
seq_lora_requests
optional_loras
=
(
[
lora_request
]
*
len
(
prompts
)
if
not
isinstance
(
lora_request
,
Sequence
)
else
lora_request
)
return
[
return
[
self
.
_resolve_single_prompt_mm_lora
(
self
.
_resolve_single_prompt_mm_lora
(
prompt
,
prompt
,
opt_
lora_req
,
lora_req
,
lora_config
.
default_mm_loras
,
lora_config
.
default_mm_loras
,
)
)
for
prompt
,
opt_
lora_req
in
zip
(
prompts
,
optional_lora
s
)
for
prompt
,
lora_req
in
zip
(
prompts
,
seq_lora_request
s
)
]
]
def
_resolve_single_prompt_mm_lora
(
def
_resolve_single_prompt_mm_lora
(
self
,
self
,
prompt
:
DictPrompt
|
TokPrompt
,
prompt
:
ProcessorInputs
,
lora_request
:
LoRARequest
|
None
,
lora_request
:
LoRARequest
|
None
,
default_mm_loras
:
dict
[
str
,
str
]
|
None
,
default_mm_loras
:
dict
[
str
,
str
]
|
None
,
):
):
if
not
default_mm_loras
or
not
(
if
not
default_mm_loras
or
prompt
[
"type"
]
!=
"multimodal"
:
mm_data
:
=
prompt
.
get
(
"multi_modal_data"
)
or
{}
):
return
lora_request
return
lora_request
intersection
=
set
(
prompt_modalities
=
prompt
[
"mm_placeholders"
].
keys
()
mm_data
.
keys
()
# type: ignore
intersection
=
set
(
prompt_modalities
).
intersection
(
default_mm_loras
.
keys
())
).
intersection
(
default_mm_loras
.
keys
())
if
not
intersection
:
if
not
intersection
:
return
lora_request
return
lora_request
if
len
(
intersection
)
>
1
:
if
len
(
intersection
)
>
1
:
...
@@ -674,22 +660,6 @@ class LLM:
...
@@ -674,22 +660,6 @@ class LLM:
"""
"""
return
self
.
llm_engine
.
apply_model
(
func
)
return
self
.
llm_engine
.
apply_model
(
func
)
def
_get_beam_search_lora_requests
(
self
,
lora_request
:
list
[
LoRARequest
]
|
LoRARequest
|
None
,
prompts
:
list
[
TokensPrompt
|
TextPrompt
],
)
->
list
[
LoRARequest
|
None
]:
"""Get the optional lora request corresponding to each prompt."""
if
isinstance
(
lora_request
,
Sequence
)
and
len
(
lora_request
)
!=
len
(
prompts
):
raise
ValueError
(
"Lora request list should be the same length as the prompts"
)
if
lora_request
is
None
or
isinstance
(
lora_request
,
LoRARequest
):
return
[
lora_request
]
*
len
(
prompts
)
raise
TypeError
(
f
"Invalid lora_request type
{
type
(
lora_request
)
}
"
)
def
beam_search
(
def
beam_search
(
self
,
self
,
prompts
:
list
[
TokensPrompt
|
TextPrompt
],
prompts
:
list
[
TokensPrompt
|
TextPrompt
],
...
@@ -718,13 +688,12 @@ class LLM:
...
@@ -718,13 +688,12 @@ class LLM:
ignore_eos
=
params
.
ignore_eos
ignore_eos
=
params
.
ignore_eos
length_penalty
=
params
.
length_penalty
length_penalty
=
params
.
length_penalty
lora_requests
=
self
.
_get_beam_search_lora_requests
(
lora_request
,
prompts
)
tokenizer
=
self
.
renderer
.
get_tokenizer
()
eos_token_id
=
tokenizer
.
eos_token_id
sort_beams_key
=
create_sort_beams_key_function
(
eos_token_id
,
length_penalty
)
tokenizer
=
self
.
get_tokenizer
()
engine_prompts
=
self
.
_preprocess_cmpl
(
prompts
)
sort_beams_key
=
create_sort_beams_key_function
(
lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
engine_prompts
))
tokenizer
.
eos_token_id
,
length_penalty
,
)
if
use_tqdm
and
concurrency_limit
is
not
None
:
if
use_tqdm
and
concurrency_limit
is
not
None
:
logger
.
warning
(
logger
.
warning
(
...
@@ -734,21 +703,12 @@ class LLM:
...
@@ -734,21 +703,12 @@ class LLM:
use_tqdm
=
False
use_tqdm
=
False
if
concurrency_limit
is
None
:
if
concurrency_limit
is
None
:
concurrency_limit
=
len
(
prompts
)
concurrency_limit
=
len
(
engine_prompts
)
def
create_tokens_prompt_from_beam
(
beam
:
BeamSearchSequence
)
->
TokensPrompt
:
token_prompt_kwargs
:
TokensPrompt
=
{
"prompt_token_ids"
:
beam
.
tokens
}
if
beam
.
multi_modal_data
is
not
None
:
token_prompt_kwargs
[
"multi_modal_data"
]
=
beam
.
multi_modal_data
if
beam
.
mm_processor_kwargs
is
not
None
:
token_prompt_kwargs
[
"mm_processor_kwargs"
]
=
beam
.
mm_processor_kwargs
return
TokensPrompt
(
**
token_prompt_kwargs
)
# generate 2 * beam_width candidates at each step
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search
_params
=
SamplingParams
(
sampling
_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
max_tokens
=
1
,
temperature
=
temperature
,
temperature
=
temperature
,
...
@@ -756,30 +716,25 @@ class LLM:
...
@@ -756,30 +716,25 @@ class LLM:
)
)
instances
:
list
[
BeamSearchInstance
]
=
[]
instances
:
list
[
BeamSearchInstance
]
=
[]
for
lora_req
,
prompt
in
zip
(
lora_requests
,
prompts
):
for
lora_req
,
prompt
in
zip
(
lora_requests
,
engine_prompts
):
# Add multimodal processor kwargs & data
if
prompt
[
"type"
]
==
"embeds"
:
mm_kwargs
=
{}
raise
NotImplementedError
(
if
"multi_modal_data"
in
prompt
:
"Embedding prompt not supported for beam search"
mm_kwargs
[
"multi_modal_data"
]
=
prompt
[
"multi_modal_data"
]
)
if
"mm_processor_kwargs"
in
prompt
:
if
prompt
[
"type"
]
==
"enc_dec"
:
mm_kwargs
[
"mm_processor_kwargs"
]
=
prompt
[
"mm_processor_kwargs"
]
raise
NotImplementedError
(
"Encoder-decoder prompt not supported for beam search"
if
"prompt_token_ids"
in
prompt
:
)
prompt
=
cast
(
TokensPrompt
,
prompt
)
# Needed for mypy
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
instances
.
append
(
BeamSearchInstance
(
BeamSearchInstance
(
prompt
_tokens
,
prompt
,
lora_request
=
lora_req
,
lora_request
=
lora_req
,
logprobs
=
None
,
logprobs
=
None
,
**
mm_kwargs
,
),
),
)
)
for
prompt_start
in
range
(
0
,
len
(
prompt
s
),
concurrency_limit
):
for
prompt_start
in
range
(
0
,
len
(
instance
s
),
concurrency_limit
):
instances_batch
=
instances
[
prompt_start
:
prompt_start
+
concurrency_limit
]
instances_batch
=
instances
[
prompt_start
:
prompt_start
+
concurrency_limit
]
token_iter
=
range
(
max_tokens
)
token_iter
=
range
(
max_tokens
)
...
@@ -808,22 +763,15 @@ class LLM:
...
@@ -808,22 +763,15 @@ class LLM:
if
len
(
all_beams
)
==
0
:
if
len
(
all_beams
)
==
0
:
break
break
# create corresponding batch entries for prompt & optional lora
prompts_batch
,
lora_req_batch
=
zip
(
*
[
(
create_tokens_prompt_from_beam
(
beam
),
beam
.
lora_request
)
for
beam
in
all_beams
]
)
# only runs for one step
# only runs for one step
# we don't need to use tqdm here
# we don't need to use tqdm here
output
=
self
.
generate
(
raw_output
=
self
.
_render_and_run_requests
(
prompts_batch
,
prompts
=
(
beam
.
get_prompt
()
for
beam
in
all_beams
),
sampling_params
=
beam_search_params
,
params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
all_beams
)),
lora_requests
=
[
beam
.
lora_request
for
beam
in
all_beams
],
use_tqdm
=
False
,
use_tqdm
=
False
,
lora_request
=
lora_req_batch
,
)
)
output
=
self
.
engine_class
.
validate_outputs
(
raw_output
,
RequestOutput
)
for
(
start
,
end
),
instance
in
zip
(
for
(
start
,
end
),
instance
in
zip
(
instance_start_and_end
,
instances_batch
instance_start_and_end
,
instances_batch
...
@@ -841,19 +789,15 @@ class LLM:
...
@@ -841,19 +789,15 @@ class LLM:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
new_beam
=
BeamSearchSequence
(
current_beam
.
orig_prompt
,
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
lora_request
=
current_beam
.
lora_request
,
lora_request
=
current_beam
.
lora_request
,
cum_logprob
=
current_beam
.
cum_logprob
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
+
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
,
)
)
if
(
if
token_id
==
eos_token_id
and
not
ignore_eos
:
token_id
==
tokenizer
.
eos_token_id
and
not
ignore_eos
):
instance
.
completed
.
append
(
new_beam
)
instance
.
completed
.
append
(
new_beam
)
else
:
else
:
instance_new_beams
.
append
(
new_beam
)
instance_new_beams
.
append
(
new_beam
)
...
@@ -872,6 +816,7 @@ class LLM:
...
@@ -872,6 +816,7 @@ class LLM:
for
beam
in
best_beams
:
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
)
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
)
outputs
.
append
(
BeamSearchOutput
(
sequences
=
best_beams
))
outputs
.
append
(
BeamSearchOutput
(
sequences
=
best_beams
))
return
outputs
return
outputs
...
@@ -880,7 +825,7 @@ class LLM:
...
@@ -880,7 +825,7 @@ class LLM:
self
,
self
,
prompts
:
Sequence
[
PromptType
],
prompts
:
Sequence
[
PromptType
],
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Sequence
[
DictPrompt
|
TokPrompt
]:
)
->
Sequence
[
ProcessorInputs
]:
"""
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
a format that can be passed to `_add_request`.
...
@@ -888,8 +833,7 @@ class LLM:
...
@@ -888,8 +833,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments.
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
after chat template interpolation, and the raw multi-modal inputs.
"""
"""
renderer
=
self
.
renderer
renderer
=
self
.
renderer
model_config
=
self
.
model_config
model_config
=
self
.
model_config
...
@@ -903,6 +847,14 @@ class LLM:
...
@@ -903,6 +847,14 @@ class LLM:
return
renderer
.
render_cmpl
(
parsed_prompts
,
tok_params
)
return
renderer
.
render_cmpl
(
parsed_prompts
,
tok_params
)
def
_preprocess_cmpl_one
(
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
ProcessorInputs
:
(
engine_prompt
,)
=
self
.
_preprocess_cmpl
([
prompt
],
tokenization_kwargs
)
return
engine_prompt
def
_preprocess_chat
(
def
_preprocess_chat
(
self
,
self
,
conversations
:
Sequence
[
list
[
ChatCompletionMessageParam
]],
conversations
:
Sequence
[
list
[
ChatCompletionMessageParam
]],
...
@@ -914,7 +866,7 @@ class LLM:
...
@@ -914,7 +866,7 @@ class LLM:
tools
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tools
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Sequence
[
TokPrompt
]:
)
->
Sequence
[
ProcessorInputs
]:
"""
"""
Convert a list of conversations into prompts so that they can then
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
be used as input for other LLM APIs.
...
@@ -922,8 +874,7 @@ class LLM:
...
@@ -922,8 +874,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments.
Refer to [LLM.chat][] for a complete description of the arguments.
Returns:
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
after chat template interpolation, and the raw multi-modal inputs.
"""
"""
renderer
=
self
.
renderer
renderer
=
self
.
renderer
...
@@ -953,13 +904,39 @@ class LLM:
...
@@ -953,13 +904,39 @@ class LLM:
return
engine_prompts
return
engine_prompts
def
_preprocess_chat_one
(
self
,
conversation
:
list
[
ChatCompletionMessageParam
],
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
chat_template_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
ProcessorInputs
:
(
engine_prompt
,)
=
self
.
_preprocess_chat
(
[
conversation
],
chat_template
=
chat_template
,
chat_template_content_format
=
chat_template_content_format
,
chat_template_kwargs
=
chat_template_kwargs
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
tokenization_kwargs
=
tokenization_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
return
engine_prompt
def
chat
(
def
chat
(
self
,
self
,
messages
:
list
[
ChatCompletionMessageParam
]
messages
:
list
[
ChatCompletionMessageParam
]
|
Sequence
[
list
[
ChatCompletionMessageParam
]],
|
Sequence
[
list
[
ChatCompletionMessageParam
]],
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
=
None
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
chat_template
:
str
|
None
=
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
...
@@ -1805,47 +1782,41 @@ class LLM:
...
@@ -1805,47 +1782,41 @@ class LLM:
|
Sequence
[
SamplingParams
|
PoolingParams
],
|
Sequence
[
SamplingParams
|
PoolingParams
],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
list
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
):
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_params
=
self
.
_params_to_seq
(
params
,
len
(
seq_prompts
))
seq_params
=
self
.
_params_to_seq
(
params
,
len
(
seq_prompts
))
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
seq_prompts
))
if
any
(
param
.
truncate_prompt_tokens
is
not
None
for
param
in
seq_params
):
seq_tok_kwargs
=
[
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
merge_kwargs
(
# Then, move the code from the `else` block to the top and let
tokenization_kwargs
,
# `self._preprocess_cmpl` handle prompt normalization
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
engine_prompts
:
Sequence
[
DictPrompt
|
TokPrompt
]
=
[
)
engine_prompt
for
param
in
seq_params
for
prompt
,
param
in
zip
(
seq_prompts
,
seq_params
)
]
for
engine_prompt
in
self
.
_preprocess_cmpl
(
seq_priority
=
self
.
_priority_to_seq
(
priority
,
len
(
prompts
))
[
prompt
],
tokenization_kwargs
=
merge_kwargs
(
return
self
.
_render_and_run_requests
(
tokenization_kwargs
,
prompts
=
(
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
self
.
_preprocess_cmpl_one
(
prompt
,
tok_kwargs
)
for
prompt
,
tok_kwargs
in
zip
(
maybe_tqdm
(
seq_prompts
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering prompts"
,
),
),
seq_tok_kwargs
,
)
)
]
),
else
:
engine_prompts
=
self
.
_preprocess_cmpl
(
seq_prompts
,
tokenization_kwargs
=
tokenization_kwargs
,
)
self
.
_validate_and_add_requests
(
prompts
=
engine_prompts
,
params
=
seq_params
,
params
=
seq_params
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
self
.
_get_modality_specific_lora_reqs
(
lora_requests
=
seq_lora_requests
,
engine_prompts
,
lora_request
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priorit
y
=
priority
,
priorit
ies
=
seq_
priority
,
)
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
def
_run_chat
(
def
_run_chat
(
self
,
self
,
messages
:
list
[
ChatCompletionMessageParam
]
messages
:
list
[
ChatCompletionMessageParam
]
...
@@ -1855,7 +1826,7 @@ class LLM:
...
@@ -1855,7 +1826,7 @@ class LLM:
|
Sequence
[
SamplingParams
|
PoolingParams
],
|
Sequence
[
SamplingParams
|
PoolingParams
],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
chat_template
:
str
|
None
=
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
...
@@ -1865,68 +1836,94 @@ class LLM:
...
@@ -1865,68 +1836,94 @@ class LLM:
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
):
engine_prompts
=
self
.
_preprocess_chat
(
seq_convs
=
conversation_to_seq
(
messages
)
conversation_to_seq
(
messages
),
seq_params
=
self
.
_params_to_seq
(
params
,
len
(
seq_convs
))
chat_template
=
chat_template
,
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
seq_convs
))
chat_template_content_format
=
chat_template_content_format
,
seq_tok_kwargs
=
[
chat_template_kwargs
=
chat_template_kwargs
,
merge_kwargs
(
add_generation_prompt
=
add_generation_prompt
,
tokenization_kwargs
,
continue_final_message
=
continue_final_message
,
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
tools
=
tools
,
)
for
param
in
seq_params
]
return
self
.
_render_and_run_requests
(
prompts
=
(
self
.
_preprocess_chat_one
(
conversation
,
chat_template
=
chat_template
,
chat_template_content_format
=
chat_template_content_format
,
chat_template_kwargs
=
chat_template_kwargs
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
tokenization_kwargs
=
tok_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
for
conversation
,
tok_kwargs
in
zip
(
maybe_tqdm
(
seq_convs
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering conversations"
,
),
seq_tok_kwargs
,
)
),
params
=
seq_params
,
lora_requests
=
seq_lora_requests
,
use_tqdm
=
use_tqdm
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
)
self
.
_validate_and_add_requests
(
def
_render_and_run_requests
(
prompts
=
engine_prompts
,
self
,
prompts
:
Iterable
[
ProcessorInputs
],
params
:
Sequence
[
SamplingParams
|
PoolingParams
],
*
,
lora_requests
:
Sequence
[
LoRARequest
|
None
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
priorities
:
Sequence
[
int
]
|
None
=
None
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
):
if
isinstance
(
prompts
,
(
list
,
tuple
)):
logger
.
warning_once
(
"Rendering all prompts before adding them to the engine "
"is less efficient than performing both on the same prompt "
"before processing the next prompt. You should instead pass "
"a generator that renders one prompt per iteration, as that allows "
"engine execution to begin for the first prompt while processing "
"the next prompt."
)
self
.
_render_and_add_requests
(
prompts
=
prompts
,
params
=
params
,
params
=
params
,
use_tqdm
=
use_tqdm
,
lora_requests
=
lora_requests
,
lora_request
=
self
.
_get_modality_specific_lora_reqs
(
engine_prompts
,
lora_request
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priorities
=
priorities
,
)
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
def
_
validate
_and_add_requests
(
def
_
render
_and_add_requests
(
self
,
self
,
prompts
:
Sequence
[
DictPrompt
|
TokPrompt
],
prompts
:
Iterable
[
ProcessorInputs
],
params
:
SamplingParams
params
:
Sequence
[
SamplingParams
|
PoolingParams
],
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
*
,
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_requests
:
Sequence
[
LoRARequest
|
None
]
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
|
None
]
|
LoRARequest
|
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
priorit
y
:
list
[
int
]
|
None
=
None
,
priorit
ies
:
Sequence
[
int
]
|
None
=
None
,
)
->
list
[
str
]:
)
->
list
[
str
]:
num_requests
=
len
(
prompts
)
seq_params
=
self
.
_params_to_seq
(
params
,
num_requests
)
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
num_requests
)
seq_priority
=
self
.
_priority_to_seq
(
priority
,
num_requests
)
for
sp
in
seq_params
:
if
isinstance
(
sp
,
SamplingParams
):
# We only care about the final output
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
it
=
prompts
if
use_tqdm
:
tqdm_func
=
use_tqdm
if
callable
(
use_tqdm
)
else
tqdm
it
=
tqdm_func
(
it
,
desc
=
"Adding requests"
)
added_request_ids
:
list
[
str
]
=
[]
added_request_ids
:
list
[
str
]
=
[]
try
:
try
:
for
i
,
prompt
in
enumerate
(
it
):
for
i
,
prompt
in
enumerate
(
prompts
):
request_id
=
self
.
_add_request
(
request_id
=
self
.
_add_request
(
prompt
,
prompt
,
seq_
params
[
i
],
params
[
i
],
lora_request
=
seq_
lora_requests
[
i
],
lora_request
=
None
if
lora_requests
is
None
else
lora_requests
[
i
],
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priority
=
seq_
priorit
y
[
i
],
priority
=
0
if
priorities
is
None
else
priorit
ies
[
i
],
)
)
added_request_ids
.
append
(
request_id
)
added_request_ids
.
append
(
request_id
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -1938,13 +1935,16 @@ class LLM:
...
@@ -1938,13 +1935,16 @@ class LLM:
def
_add_request
(
def
_add_request
(
self
,
self
,
prompt
:
Pro
mptType
|
DictPrompt
|
TokPrompt
,
prompt
:
Pro
cessorInputs
,
params
:
SamplingParams
|
PoolingParams
,
params
:
SamplingParams
|
PoolingParams
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
str
:
)
->
str
:
prompt_text
,
_
,
_
=
extract_prompt_components
(
self
.
model_config
,
prompt
)
if
isinstance
(
params
,
SamplingParams
):
# We only care about the final output
params
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
if
params
.
truncate_prompt_tokens
is
not
None
:
if
params
.
truncate_prompt_tokens
is
not
None
:
...
@@ -1962,32 +1962,14 @@ class LLM:
...
@@ -1962,32 +1962,14 @@ class LLM:
dict
(
truncate_prompt_tokens
=
params
.
truncate_prompt_tokens
),
dict
(
truncate_prompt_tokens
=
params
.
truncate_prompt_tokens
),
)
)
renderer
=
self
.
renderer
return
self
.
llm_engine
.
add_request
(
tok_params
=
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
**
(
tokenization_kwargs
or
{})
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
engine_request
=
self
.
input_processor
.
process_inputs
(
request_id
,
request_id
,
prompt
,
prompt
,
params
,
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
priority
=
priority
,
priority
=
priority
,
supported_tasks
=
self
.
supported_tasks
,
)
self
.
llm_engine
.
add_request
(
request_id
,
engine_request
,
params
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
priority
=
priority
,
prompt_text
=
prompt_text
,
)
)
return
engine_request
.
request_id
def
_run_engine
(
def
_run_engine
(
self
,
self
,
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
574fe752
...
@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
...
@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
)
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
ProcessorInputs
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.parser
import
ParserManager
from
vllm.parser
import
ParserManager
from
vllm.reasoning
import
ReasoningParser
from
vllm.reasoning
import
ReasoningParser
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.mistral
import
(
from
vllm.tokenizers.mistral
import
(
...
@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
async
def
render_chat_request
(
async
def
render_chat_request
(
self
,
self
,
request
:
ChatCompletionRequest
,
request
:
ChatCompletionRequest
,
)
->
tuple
[
list
[
ConversationMessage
],
list
[
TokPrompt
]]
|
ErrorResponse
:
)
->
tuple
[
list
[
ConversationMessage
],
list
[
ProcessorInputs
]]
|
ErrorResponse
:
"""
"""
render chat request by validating and preprocessing inputs.
render chat request by validating and preprocessing inputs.
...
@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
=
self
.
_extract_prompt_text
(
engine_prompt
)
prompt_token_ids
=
self
.
_extract_prompt_components
(
engine_prompt
).
token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
# have unique request ids.
...
@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
)
)
else
:
else
:
tok_params
=
request
.
build_tok_params
(
self
.
model_config
)
reasoning_ended
=
(
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
reasoning_parser
.
is_reasoning_end
(
prompt_token_ids
or
[])
if
reasoning_parser
engine_request
=
self
.
input_processor
.
process_inputs
(
else
None
sub_request_id
,
engine_prompt
,
sampling_params
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
)
reasoning_ended
=
None
if
reasoning_parser
:
reasoning_ended
=
reasoning_parser
.
is_reasoning_end
(
engine_request
.
prompt_token_ids
or
[]
# type: ignore[attr-defined]
)
engine_request
.
reasoning_ended
=
reasoning_ended
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
engine_
reques
t
,
engine_
promp
t
,
sampling_params
,
sampling_params
,
sub_request_id
,
sub_request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
prompt_text
=
prompt_text
,
tokenization_kwargs
=
tokenization_kwargs
,
data_parallel_rank
=
data_parallel_rank
,
data_parallel_rank
=
data_parallel_rank
,
reasoning_ended
=
reasoning_ended
,
)
)
generators
.
append
(
generator
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/completion/serving.py
View file @
574fe752
...
@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
...
@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
ProcessorInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.async_utils
import
merge_async_iterators
...
@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
async
def
render_completion_request
(
async
def
render_completion_request
(
self
,
self
,
request
:
CompletionRequest
,
request
:
CompletionRequest
,
)
->
list
[
TokPrompt
]
|
ErrorResponse
:
)
->
list
[
ProcessorInputs
]
|
ErrorResponse
:
"""
"""
render completion request by validating and preprocessing inputs.
render completion request by validating and preprocessing inputs.
...
@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
=
self
.
_extract_prompt_text
(
engine_prompt
)
max_tokens
=
get_max_tokens
(
max_tokens
=
get_max_tokens
(
max_model_len
,
max_model_len
,
request
.
max_tokens
,
request
.
max_tokens
,
...
@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
)
)
else
:
else
:
tok_params
=
request
.
build_tok_params
(
self
.
model_config
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
engine_request
=
self
.
input_processor
.
process_inputs
(
request_id_item
,
engine_prompt
,
sampling_params
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
engine_
reques
t
,
engine_
promp
t
,
sampling_params
,
sampling_params
,
request_id_item
,
request_id_item
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
prompt_text
=
prompt_text
,
tokenization_kwargs
=
tokenization_kwargs
,
data_parallel_rank
=
data_parallel_rank
,
data_parallel_rank
=
data_parallel_rank
,
)
)
...
@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
async
def
completion_stream_generator
(
async
def
completion_stream_generator
(
self
,
self
,
request
:
CompletionRequest
,
request
:
CompletionRequest
,
engine_prompts
:
list
[
TokPrompt
],
engine_prompts
:
list
[
ProcessorInputs
],
result_generator
:
AsyncIterator
[
tuple
[
int
,
RequestOutput
]],
result_generator
:
AsyncIterator
[
tuple
[
int
,
RequestOutput
]],
request_id
:
str
,
request_id
:
str
,
created_time
:
int
,
created_time
:
int
,
...
...
vllm/entrypoints/openai/engine/serving.py
View file @
574fe752
...
@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
...
@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
)
from
vllm.entrypoints.utils
import
get_max_tokens
,
sanitize_message
from
vllm.entrypoints.utils
import
get_max_tokens
,
sanitize_message
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
PromptType
,
SingletonPrompt
,
TokensPrompt
from
vllm.inputs.data
import
(
ProcessorInputs
,
PromptType
,
SingletonPrompt
,
TokensPrompt
,
token_inputs
,
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
ChatParams
,
TokenizeParams
,
merge_kwargs
from
vllm.renderers
import
ChatParams
,
TokenizeParams
,
merge_kwargs
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.renderers.inputs.preprocess
import
(
from
vllm.renderers.inputs.preprocess
import
(
extract_prompt_components
,
extract_prompt_components
,
extract_prompt_len
,
extract_prompt_len
,
...
@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
...
@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
request_id
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
lora_request
:
LoRARequest
|
None
=
None
engine_prompts
:
list
[
TokPrompt
]
|
None
=
None
engine_prompts
:
list
[
ProcessorInputs
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
None
...
@@ -249,7 +253,7 @@ class OpenAIServing:
...
@@ -249,7 +253,7 @@ class OpenAIServing:
async
def
beam_search
(
async
def
beam_search
(
self
,
self
,
prompt
:
TokPrompt
,
prompt
:
ProcessorInputs
,
request_id
:
str
,
request_id
:
str
,
params
:
BeamSearchParams
,
params
:
BeamSearchParams
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
...
@@ -262,86 +266,53 @@ class OpenAIServing:
...
@@ -262,86 +266,53 @@ class OpenAIServing:
length_penalty
=
params
.
length_penalty
length_penalty
=
params
.
length_penalty
include_stop_str_in_output
=
params
.
include_stop_str_in_output
include_stop_str_in_output
=
params
.
include_stop_str_in_output
input_processor
=
self
.
input_processor
tokenizer
=
self
.
renderer
.
get_tokenizer
()
tokenizer
=
input_processor
.
tokenizer
eos_token_id
=
tokenizer
.
eos_token_id
if
tokenizer
is
None
:
sort_beams_key
=
create_sort_beams_key_function
(
eos_token_id
,
length_penalty
)
raise
VLLMValidationError
(
"You cannot use beam search when `skip_tokenizer_init=True`"
,
parameter
=
"skip_tokenizer_init"
,
value
=
True
,
)
eos_token_id
:
int
=
tokenizer
.
eos_token_id
# type: ignore
if
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
:
raise
NotImplementedError
(
"Encoder-decoder prompt not supported"
)
prompt_text
:
str
|
None
=
prompt
.
get
(
"prompt"
)
# type: ignore
prompt_token_ids
:
list
[
int
]
=
prompt
.
get
(
"prompt_token_ids"
,
[])
# type: ignore
multi_modal_data
:
MultiModalDataDict
|
None
=
prompt
.
get
(
"multi_modal_data"
)
# type: ignore
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
# This is a workaround to fix multimodal beam search; this is a
if
prompt
[
"type"
]
==
"embeds"
:
# bandaid fix for 2 small problems:
raise
NotImplementedError
(
"Embedding prompt not supported for beam search"
)
# 1. Multi_modal_data on the processed_inputs currently resolves to
if
prompt
[
"type"
]
==
"enc_dec"
:
# `None`.
raise
NotImplementedError
(
# 2. preprocessing above expands the multimodal placeholders. However,
"Encoder-decoder prompt not supported for beam search"
# this happens again in generation, so the double expansion causes
)
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text
=
prompt
.
get
(
"prompt"
)
prompt_token_ids
=
prompt
[
"prompt_token_ids"
]
tokenized_length
=
len
(
prompt_token_ids
)
tokenized_length
=
len
(
prompt_token_ids
)
sort_beams_key
=
create_sort_beams_key_function
(
eos_token_id
,
length_penalty
)
logprobs_num
=
2
*
beam_width
logprobs_num
=
2
*
beam_width
beam_search
_params
=
SamplingParams
(
sampling
_params
=
SamplingParams
(
logprobs
=
logprobs_num
,
logprobs
=
logprobs_num
,
max_tokens
=
1
,
max_tokens
=
1
,
temperature
=
temperature
,
temperature
=
temperature
,
)
)
all_beams
=
[
all_beams
=
[
BeamSearchSequence
(
BeamSearchSequence
(
orig_prompt
=
prompt
,
tokens
=
prompt_token_ids
,
tokens
=
prompt_token_ids
,
cum_logprob
=
0
,
cum_logprob
=
0
,
logprobs
=
[],
logprobs
=
[],
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
]
]
completed
=
[]
completed
=
[]
for
_
in
range
(
max_tokens
):
for
_
in
range
(
max_tokens
):
prompts_batch
,
lora_req_batch
=
zip
(
*
[
(
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
,
multi_modal_data
=
beam
.
multi_modal_data
,
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
,
),
beam
.
lora_request
,
)
for
beam
in
all_beams
]
)
tasks
=
[]
tasks
=
[]
request_id_batch
=
f
"
{
request_id
}
-
{
random_uuid
()
}
"
request_id_batch
=
f
"
{
request_id
}
-
{
random_uuid
()
}
"
for
i
,
(
individual_prompt
,
lora_req
)
in
enumerate
(
for
i
,
beam
in
enumerate
(
all_beams
):
zip
(
prompt
s_batch
,
lora_req_batch
)
prompt
_item
=
beam
.
get_prompt
(
)
):
lora_request_item
=
beam
.
lora_request
request_id_item
=
f
"
{
request_id_batch
}
-beam-
{
i
}
"
request_id_item
=
f
"
{
request_id_batch
}
-beam-
{
i
}
"
task
=
asyncio
.
create_task
(
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
collect_from_async_generator
(
self
.
engine_client
.
generate
(
self
.
engine_client
.
generate
(
individual_
prompt
,
prompt
_item
,
beam_search
_params
,
sampling
_params
,
request_id_item
,
request_id_item
,
lora_request
=
lora_req
,
lora_request
=
lora_req
uest_item
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
)
)
)
)
...
@@ -406,6 +377,7 @@ class OpenAIServing:
...
@@ -406,6 +377,7 @@ class OpenAIServing:
logprobs_entry
=
result
.
outputs
[
0
].
logprobs
[
0
]
logprobs_entry
=
result
.
outputs
[
0
].
logprobs
[
0
]
completed
.
append
(
completed
.
append
(
BeamSearchSequence
(
BeamSearchSequence
(
orig_prompt
=
prompt
,
tokens
=
current_beam
.
tokens
+
[
eos_token_id
]
tokens
=
current_beam
.
tokens
+
[
eos_token_id
]
if
include_stop_str_in_output
if
include_stop_str_in_output
else
current_beam
.
tokens
,
else
current_beam
.
tokens
,
...
@@ -433,12 +405,11 @@ class OpenAIServing:
...
@@ -433,12 +405,11 @@ class OpenAIServing:
logprobs_entry
=
result
.
outputs
[
0
].
logprobs
[
0
]
logprobs_entry
=
result
.
outputs
[
0
].
logprobs
[
0
]
new_beams
.
append
(
new_beams
.
append
(
BeamSearchSequence
(
BeamSearchSequence
(
orig_prompt
=
prompt
,
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs_entry
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs_entry
],
lora_request
=
current_beam
.
lora_request
,
lora_request
=
current_beam
.
lora_request
,
cum_logprob
=
float
(
all_beams_logprob
[
idx
]),
cum_logprob
=
float
(
all_beams_logprob
[
idx
]),
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
,
)
)
)
)
...
@@ -958,7 +929,7 @@ class OpenAIServing:
...
@@ -958,7 +929,7 @@ class OpenAIServing:
request
:
RendererRequest
,
request
:
RendererRequest
,
prompt_input
:
str
|
list
[
str
]
|
list
[
int
]
|
list
[
list
[
int
]]
|
None
,
prompt_input
:
str
|
list
[
str
]
|
list
[
int
]
|
list
[
list
[
int
]]
|
None
,
prompt_embeds
:
bytes
|
list
[
bytes
]
|
None
,
prompt_embeds
:
bytes
|
list
[
bytes
]
|
None
,
)
->
list
[
TokPrompt
]:
)
->
list
[
ProcessorInputs
]:
prompts
=
list
[
SingletonPrompt
|
bytes
]()
prompts
=
list
[
SingletonPrompt
|
bytes
]()
if
prompt_embeds
is
not
None
:
# embeds take higher priority
if
prompt_embeds
is
not
None
:
# embeds take higher priority
prompts
.
extend
(
prompt_to_seq
(
prompt_embeds
))
prompts
.
extend
(
prompt_to_seq
(
prompt_embeds
))
...
@@ -971,7 +942,7 @@ class OpenAIServing:
...
@@ -971,7 +942,7 @@ class OpenAIServing:
self
,
self
,
request
:
RendererRequest
,
request
:
RendererRequest
,
prompts
:
Sequence
[
PromptType
|
bytes
],
prompts
:
Sequence
[
PromptType
|
bytes
],
)
->
list
[
TokPrompt
]:
)
->
list
[
ProcessorInputs
]:
renderer
=
self
.
renderer
renderer
=
self
.
renderer
model_config
=
self
.
model_config
model_config
=
self
.
model_config
...
@@ -1004,7 +975,7 @@ class OpenAIServing:
...
@@ -1004,7 +975,7 @@ class OpenAIServing:
default_template_kwargs
:
dict
[
str
,
Any
]
|
None
,
default_template_kwargs
:
dict
[
str
,
Any
]
|
None
,
tool_dicts
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tool_dicts
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tool_parser
:
Callable
[[
TokenizerLike
],
ToolParser
]
|
None
=
None
,
tool_parser
:
Callable
[[
TokenizerLike
],
ToolParser
]
|
None
=
None
,
)
->
tuple
[
list
[
ConversationMessage
],
list
[
TokPrompt
]]:
)
->
tuple
[
list
[
ConversationMessage
],
list
[
ProcessorInputs
]]:
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
renderer
=
self
.
renderer
renderer
=
self
.
renderer
...
@@ -1052,13 +1023,13 @@ class OpenAIServing:
...
@@ -1052,13 +1023,13 @@ class OpenAIServing:
return
conversation
,
[
engine_prompt
]
return
conversation
,
[
engine_prompt
]
def
_extract_prompt_components
(
self
,
prompt
:
object
):
def
_extract_prompt_components
(
self
,
prompt
:
PromptType
|
ProcessorInputs
):
return
extract_prompt_components
(
self
.
model_config
,
prompt
)
return
extract_prompt_components
(
self
.
model_config
,
prompt
)
def
_extract_prompt_text
(
self
,
prompt
:
object
):
def
_extract_prompt_text
(
self
,
prompt
:
ProcessorInputs
):
return
self
.
_extract_prompt_components
(
prompt
).
text
return
self
.
_extract_prompt_components
(
prompt
).
text
def
_extract_prompt_len
(
self
,
prompt
:
object
):
def
_extract_prompt_len
(
self
,
prompt
:
ProcessorInputs
):
return
extract_prompt_len
(
self
.
model_config
,
prompt
)
return
extract_prompt_len
(
self
.
model_config
,
prompt
)
async
def
_render_next_turn
(
async
def
_render_next_turn
(
...
@@ -1088,16 +1059,14 @@ class OpenAIServing:
...
@@ -1088,16 +1059,14 @@ class OpenAIServing:
async
def
_generate_with_builtin_tools
(
async
def
_generate_with_builtin_tools
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
engine_prompt
:
TokPrompt
,
engine_prompt
:
ProcessorInputs
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
tok_params
:
TokenizeParams
,
context
:
ConversationContext
,
context
:
ConversationContext
,
lora_request
:
LoRARequest
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
):
):
max_model_len
=
self
.
model_config
.
max_model_len
max_model_len
=
self
.
model_config
.
max_model_len
prompt_text
=
self
.
_extract_prompt_text
(
engine_prompt
)
orig_priority
=
priority
orig_priority
=
priority
sub_request
=
0
sub_request
=
0
...
@@ -1112,26 +1081,13 @@ class OpenAIServing:
...
@@ -1112,26 +1081,13 @@ class OpenAIServing:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
engine_request
=
self
.
input_processor
.
process_inputs
(
sub_request_id
,
engine_prompt
,
sampling_params
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
engine_
reques
t
,
engine_
promp
t
,
sampling_params
,
sampling_params
,
sub_request_id
,
sub_request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
priority
,
priority
=
priority
,
prompt_text
=
prompt_text
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
async
for
res
in
generator
:
async
for
res
in
generator
:
...
@@ -1154,11 +1110,11 @@ class OpenAIServing:
...
@@ -1154,11 +1110,11 @@ class OpenAIServing:
# Render the next prompt token ids and update sampling_params.
# Render the next prompt token ids and update sampling_params.
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
token_ids
=
context
.
render_for_completion
()
token_ids
=
context
.
render_for_completion
()
engine_prompt
=
TokensPrompt
(
prompt_
token_i
ds
=
token_ids
)
engine_prompt
=
token_i
nputs
(
token_ids
)
sampling_params
.
max_tokens
=
max_model_len
-
len
(
token_ids
)
sampling_params
.
max_tokens
=
max_model_len
-
len
(
token_ids
)
elif
isinstance
(
context
,
ParsableContext
):
elif
isinstance
(
context
,
ParsableContext
):
engine_prompt
s
=
await
self
.
_render_next_turn
(
(
engine_prompt
,)
=
await
self
.
_render_next_turn
(
context
.
request
,
context
.
request
,
context
.
parser
.
response_messages
,
context
.
parser
.
response_messages
,
context
.
tool_dicts
,
context
.
tool_dicts
,
...
@@ -1166,8 +1122,6 @@ class OpenAIServing:
...
@@ -1166,8 +1122,6 @@ class OpenAIServing:
context
.
chat_template
,
context
.
chat_template
,
context
.
chat_template_content_format
,
context
.
chat_template_content_format
,
)
)
engine_prompt
=
engine_prompts
[
0
]
prompt_text
=
self
.
_extract_prompt_text
(
engine_prompt
)
sampling_params
.
max_tokens
=
get_max_tokens
(
sampling_params
.
max_tokens
=
get_max_tokens
(
max_model_len
,
max_model_len
,
...
@@ -1184,7 +1138,7 @@ class OpenAIServing:
...
@@ -1184,7 +1138,7 @@ class OpenAIServing:
def
_log_inputs
(
def
_log_inputs
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
PromptType
|
TokPrompt
,
inputs
:
PromptType
|
ProcessorInputs
,
params
:
SamplingParams
|
PoolingParams
|
BeamSearchParams
|
None
,
params
:
SamplingParams
|
PoolingParams
|
BeamSearchParams
|
None
,
lora_request
:
LoRARequest
|
None
,
lora_request
:
LoRARequest
|
None
,
)
->
None
:
)
->
None
:
...
...
vllm/entrypoints/openai/realtime/serving.py
View file @
574fe752
...
@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
...
@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.interfaces
import
SupportsRealtime
from
vllm.model_executor.models.interfaces
import
SupportsRealtime
from
vllm.renderers.inputs.preprocess
import
parse_model_prompt
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
...
@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
Yields:
Yields:
StreamingInput objects containing audio prompts for the engine
StreamingInput objects containing audio prompts for the engine
"""
"""
model_config
=
self
.
model_config
renderer
=
self
.
renderer
# mypy is being stupid
# mypy is being stupid
# TODO(Patrick) - fix this
# TODO(Patrick) - fix this
stream_input_iter
=
cast
(
stream_input_iter
=
cast
(
AsyncGenerator
[
PromptType
,
None
],
AsyncGenerator
[
PromptType
,
None
],
self
.
model_cls
.
buffer_realtime_audio
(
self
.
model_cls
.
buffer_realtime_audio
(
audio_stream
,
input_stream
,
self
.
model_config
audio_stream
,
input_stream
,
model_config
),
),
)
)
async
for
prompt
in
stream_input_iter
:
async
for
prompt
in
stream_input_iter
:
yield
StreamingInput
(
prompt
=
prompt
)
parsed_prompt
=
parse_model_prompt
(
model_config
,
prompt
)
(
engine_prompt
,)
=
await
renderer
.
render_cmpl_async
([
parsed_prompt
])
yield
StreamingInput
(
prompt
=
engine_prompt
)
vllm/entrypoints/openai/responses/context.py
View file @
574fe752
...
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
...
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
contextlib
import
AsyncExitStack
from
contextlib
import
AsyncExitStack
from
dataclasses
import
replace
from
dataclasses
import
replace
from
typing
import
TYPE_CHECKING
,
Union
from
typing
import
TYPE_CHECKING
,
Final
,
Union
from
openai.types.responses.response_function_tool_call_output_item
import
(
from
openai.types.responses.response_function_tool_call_output_item
import
(
ResponseFunctionToolCallOutputItem
,
ResponseFunctionToolCallOutputItem
,
...
@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
...
@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
self
.
tool_dicts
=
construct_tool_dicts
(
request
.
tools
,
request
.
tool_choice
)
self
.
tool_dicts
=
construct_tool_dicts
(
request
.
tools
,
request
.
tool_choice
)
self
.
chat_template
=
chat_template
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
input_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
self
.
input_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
self
.
output_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
self
.
output_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
...
...
vllm/entrypoints/openai/responses/serving.py
View file @
574fe752
...
@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
...
@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
)
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
ProcessorInputs
,
token_inputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
as
SampleLogprob
from
vllm.logprobs
import
Logprob
as
SampleLogprob
from
vllm.logprobs
import
SampleLogprobs
from
vllm.logprobs
import
SampleLogprobs
from
vllm.outputs
import
CompletionOutput
from
vllm.outputs
import
CompletionOutput
from
vllm.parser
import
ParserManager
from
vllm.parser
import
ParserManager
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
def
_validate_generator_input
(
def
_validate_generator_input
(
self
,
self
,
engine_prompt
:
TokPrompt
,
engine_prompt
:
ProcessorInputs
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
"""Add validations to the input to the generator here."""
"""Add validations to the input to the generator here."""
prompt_len
=
self
.
_extract_prompt_len
(
engine_prompt
)
prompt_len
=
self
.
_extract_prompt_len
(
engine_prompt
)
...
@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
self
.
default_sampling_params
default_max_tokens
,
self
.
default_sampling_params
)
)
tok_params
=
request
.
build_tok_params
(
self
.
model_config
)
trace_headers
=
(
trace_headers
=
(
None
None
...
@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
engine_prompt
=
engine_prompt
,
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
tok_params
=
tok_params
,
context
=
context
,
context
=
context
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
...
@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
messages
=
self
.
_construct_input_messages_with_harmony
(
request
,
prev_response
)
messages
=
self
.
_construct_input_messages_with_harmony
(
request
,
prev_response
)
prompt_token_ids
=
render_for_completion
(
messages
)
prompt_token_ids
=
render_for_completion
(
messages
)
engine_prompt
=
TokensPrompt
(
prompt_
token_i
ds
=
prompt_token_ids
)
engine_prompt
=
token_i
nputs
(
prompt_token_ids
)
# Add cache_salt if provided in the request
# Add cache_salt if provided in the request
if
request
.
cache_salt
is
not
None
:
if
request
.
cache_salt
is
not
None
:
...
...
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
View file @
574fe752
...
@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
...
@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment
,
TranslationSegment
,
TranslationStreamResponse
,
TranslationStreamResponse
,
)
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs
.data
import
Pro
mptType
from
vllm.inputs
import
Pro
cessorInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
FlatLogprobs
,
Logprob
from
vllm.logprobs
import
FlatLogprobs
,
Logprob
from
vllm.model_executor.models
import
SupportsTranscription
,
supports_transcription
from
vllm.model_executor.models
import
SupportsTranscription
,
supports_transcription
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.renderers.inputs
import
EncoderDecoderDictPrompt
from
vllm.renderers.inputs
import
DictPrompt
,
EncoderDecoderDictPrompt
from
vllm.renderers.inputs.preprocess
import
parse_enc_dec_prompt
from
vllm.renderers.inputs.preprocess
import
parse_enc_dec_prompt
,
parse_model_prompt
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.import_utils
import
PlaceholderModule
from
vllm.utils.import_utils
import
PlaceholderModule
...
@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
return
return
try
:
try
:
from
vllm.sampling_params
import
SamplingParams
warmup_start
=
time
.
perf_counter
()
warmup_start
=
time
.
perf_counter
()
logger
.
info
(
"Warming up multimodal input processor..."
)
logger
.
info
(
"Warming up multimodal input processor..."
)
...
@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt
=
""
,
request_prompt
=
""
,
to_language
=
None
,
to_language
=
None
,
)
)
parsed_prompt
=
parse_model_prompt
(
self
.
model_config
,
dummy_prompt
)
# Create minimal sampling params
dummy_params
=
SamplingParams
(
max_tokens
=
1
,
temperature
=
0.0
,
skip_clone
=
True
,
# Internal warmup, safe to skip clone
)
# Process the dummy input through the input processor
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
# This will trigger all the multimodal processing initialization
_
=
self
.
input_processor
.
process_inputs
(
_
=
self
.
renderer
.
render_cmpl
([
parsed_prompt
])
request_id
=
"warmup"
,
prompt
=
dummy_prompt
,
params
=
dummy_params
,
)
warmup_elapsed
=
time
.
perf_counter
()
-
warmup_start
warmup_elapsed
=
time
.
perf_counter
()
-
warmup_start
logger
.
info
(
"Input processor warmup completed in %.2fs"
,
warmup_elapsed
)
logger
.
info
(
"Input processor warmup completed in %.2fs"
,
warmup_elapsed
)
...
@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
self
,
self
,
request
:
SpeechToTextRequest
,
request
:
SpeechToTextRequest
,
audio_data
:
bytes
,
audio_data
:
bytes
,
)
->
tuple
[
list
[
Pro
mptType
],
float
]:
)
->
tuple
[
list
[
Pro
cessorInputs
],
float
]:
# Validate request
# Validate request
language
=
self
.
model_cls
.
validate_language
(
request
.
language
)
language
=
self
.
model_cls
.
validate_language
(
request
.
language
)
# Skip to_language validation to avoid extra logging for Whisper.
# Skip to_language validation to avoid extra logging for Whisper.
...
@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
and
duration
>
self
.
asr_config
.
max_audio_clip_s
and
duration
>
self
.
asr_config
.
max_audio_clip_s
)
)
chunks
=
[
y
]
if
not
do_split_audio
else
self
.
_split_audio
(
y
,
int
(
sr
))
chunks
=
[
y
]
if
not
do_split_audio
else
self
.
_split_audio
(
y
,
int
(
sr
))
p
rompts
=
[]
p
arsed_prompts
:
list
[
DictPrompt
]
=
[]
for
chunk
in
chunks
:
for
chunk
in
chunks
:
# The model has control over the construction, as long as it
# The model has control over the construction, as long as it
# returns a valid PromptType.
# returns a valid PromptType.
...
@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt
=
request
.
prompt
,
request_prompt
=
request
.
prompt
,
to_language
=
to_language
,
to_language
=
to_language
,
)
)
parsed_prompt
:
DictPrompt
if
request
.
response_format
==
"verbose_json"
:
if
request
.
response_format
==
"verbose_json"
:
prompt
=
self
.
_preprocess_verbose_prompt
(
parse_enc_dec_prompt
(
prompt
))
parsed_prompt
=
parse_enc_dec_prompt
(
prompt
)
parsed_prompt
=
self
.
_preprocess_verbose_prompt
(
parsed_prompt
)
else
:
parsed_prompt
=
parse_model_prompt
(
self
.
model_config
,
prompt
)
parsed_prompts
.
append
(
parsed_prompt
)
prompts
.
append
(
prompt
)
engine_prompts
=
await
self
.
renderer
.
render_cmpl_async
(
parsed_
prompt
s
)
return
prompts
,
duration
return
engine_
prompts
,
duration
def
_preprocess_verbose_prompt
(
self
,
prompt
:
EncoderDecoderDictPrompt
):
def
_preprocess_verbose_prompt
(
self
,
prompt
:
EncoderDecoderDictPrompt
):
dec_prompt
=
prompt
[
"decoder_prompt"
]
dec_prompt
=
prompt
[
"decoder_prompt"
]
...
@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
try
:
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
engine_
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
request
=
request
,
audio_data
=
audio_data
,
audio_data
=
audio_data
,
)
)
...
@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing):
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
e
)
return
self
.
create_error_response
(
e
)
# Schedule the request and get the result generator.
max_model_len
=
self
.
model_config
.
max_model_len
list_result_generator
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
|
None
=
None
list_result_generator
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
|
None
=
None
try
:
try
:
# Unlike most decoder-only models, whisper generation length is not
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
# generated by respecting the extra completion tokens arg.
if
request
.
max_completion_tokens
is
None
:
max_tokens
=
get_max_tokens
(
default_max_tokens
=
self
.
model_config
.
max_model_len
max_model_len
,
else
:
request
.
max_completion_tokens
,
default_max_tokens
=
min
(
0
,
self
.
model_config
.
max_model_len
,
request
.
max_completion_tokens
self
.
default_sampling_params
,
)
)
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
self
.
default_sampling_params
max_tokens
,
self
.
default_sampling_params
,
)
)
if
request
.
response_format
==
"verbose_json"
:
if
request
.
response_format
==
"verbose_json"
:
sampling_params
.
logprobs
=
1
sampling_params
.
logprobs
=
1
self
.
_log_inputs
(
request_id
,
# It will not display special tokens like <|startoftranscript|>
request
.
prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
list_result_generator
=
[]
list_result_generator
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
engine_
prompt
in
enumerate
(
engine_
prompts
):
request_id_item
=
f
"
{
request_id
}
_
{
i
}
"
request_id_item
=
f
"
{
request_id
}
_
{
i
}
"
engine_request
=
self
.
input_processor
.
process_inputs
(
self
.
_log_inputs
(
request_id_item
,
request_id_item
,
prompt
,
engine_prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
generator
=
self
.
engine_client
.
generate
(
engine_prompt
,
sampling_params
,
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
0
,
)
list_result_generator
.
append
(
self
.
engine_client
.
generate
(
engine_request
,
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
)
)
)
list_result_generator
.
append
(
generator
)
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
return
self
.
create_error_response
(
e
)
...
...
vllm/entrypoints/pooling/embed/serving.py
View file @
574fe752
...
@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
...
@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64
,
encode_pooling_output_base64
,
encode_pooling_output_float
,
encode_pooling_output_float
,
)
)
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
ProcessorInputs
,
TokensPrompt
,
token_inputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
...
@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
chunk_request_id
=
f
"
{
ctx
.
request_id
}
-prompt-
{
prompt_idx
}
-chunk-
{
chunk_idx
}
"
chunk_request_id
=
f
"
{
ctx
.
request_id
}
-prompt-
{
prompt_idx
}
-chunk-
{
chunk_idx
}
"
# Create engine prompt for this chunk
# Create engine prompt for this chunk
chunk_engine_prompt
=
TokensPrompt
(
prompt_
token_i
ds
=
chunk_tokens
)
chunk_engine_prompt
=
token_i
nputs
(
chunk_tokens
)
# Log the chunk
# Log the chunk
self
.
_log_inputs
(
self
.
_log_inputs
(
...
@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
)
)
tok_params
=
ctx
.
request
.
build_tok_params
(
self
.
model_config
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
# Create generator for this chunk and wrap it to return indices
# Create generator for this chunk and wrap it to return indices
original_generator
=
self
.
engine_client
.
encode
(
original_generator
=
self
.
engine_client
.
encode
(
chunk_engine_prompt
,
chunk_engine_prompt
,
pooling_params
,
pooling_params
,
chunk_request_id
,
chunk_request_id
,
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
ctx
.
request
.
priority
,
priority
=
ctx
.
request
.
priority
,
)
)
...
@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async
def
_create_single_prompt_generator
(
async
def
_create_single_prompt_generator
(
self
,
self
,
ctx
:
EmbeddingServeContext
,
ctx
:
EmbeddingServeContext
,
engine_prompt
:
TokPrompt
,
engine_prompt
:
ProcessorInputs
,
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_index
:
int
,
prompt_index
:
int
,
...
@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
)
)
tok_params
=
ctx
.
request
.
build_tok_params
(
self
.
model_config
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
# Return the original generator without wrapping
# Return the original generator without wrapping
return
self
.
engine_client
.
encode
(
return
self
.
engine_client
.
encode
(
engine_prompt
,
engine_prompt
,
pooling_params
,
pooling_params
,
request_id_item
,
request_id_item
,
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
ctx
.
request
.
priority
,
priority
=
ctx
.
request
.
priority
,
)
)
...
...
vllm/entrypoints/pooling/pooling/serving.py
View file @
574fe752
...
@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
...
@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64
,
encode_pooling_output_base64
,
encode_pooling_output_float
,
encode_pooling_output_float
,
)
)
from
vllm.inputs
import
Pro
mptType
from
vllm.inputs
import
Pro
cessorInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.renderers.inputs
import
TokPrompt
from
vllm.renderers.inputs.preprocess
import
prompt_to_seq
from
vllm.renderers.inputs.preprocess
import
prompt_to_seq
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.serial_utils
import
EmbedDType
,
EncodingFormat
,
Endianness
from
vllm.utils.serial_utils
import
EmbedDType
,
EncodingFormat
,
Endianness
...
@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
"dimensions is currently not supported"
)
)
engine_prompts
:
Sequence
[
Pro
mptType
|
TokPrompt
]
engine_prompts
:
Sequence
[
Pro
cessorInputs
]
if
use_io_processor
:
=
isinstance
(
request
,
IOProcessorRequest
):
if
use_io_processor
:
=
isinstance
(
request
,
IOProcessorRequest
):
if
self
.
io_processor
is
None
:
if
self
.
io_processor
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
else
:
else
:
pooling_params
=
request
.
to_pooling_params
()
# type: ignore
pooling_params
=
request
.
to_pooling_params
()
# type: ignore
tok_params
=
request
.
build_tok_params
(
self
.
model_config
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params
,
pooling_params
,
request_id_item
,
request_id_item
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
)
)
...
...
vllm/entrypoints/pooling/score/serving.py
View file @
574fe752
...
@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
...
@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
get_score_prompt
,
get_score_prompt
,
validate_score_input
,
validate_score_input
,
)
)
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
ProcessorInputs
,
TokensPrompt
,
token_inputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
...
@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
...
@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
*
(
encode_async
(
t
,
**
tokenization_kwargs
)
for
t
in
input_texts
)
*
(
encode_async
(
t
,
**
tokenization_kwargs
)
for
t
in
input_texts
)
)
)
engine_prompts
:
list
[
TokensPrompt
]
=
[]
engine_prompts
:
list
[
ProcessorInputs
]
=
[]
for
tok_result
,
input_text
in
zip
(
tokenized_prompts
,
input_texts
):
for
tok_result
,
input_text
in
zip
(
tokenized_prompts
,
input_texts
):
text_token_prompt
=
self
.
_validate_input
(
request
,
tok_result
,
input_text
)
text_token_prompt
=
self
.
_validate_input
(
request
,
tok_result
,
input_text
)
engine_prompts
.
append
(
engine_prompts
.
append
(
TokensPrompt
(
prompt_token_ids
=
text_token_prompt
[
"prompt_token_ids"
])
token_inputs
(
text_token_prompt
[
"prompt_token_ids"
],
prompt
=
input_text
,
)
)
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
...
@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
...
@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id_item
,
request_id_item
,
input_texts
[
i
]
,
engine_prompt
,
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
...
@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
*
(
encode_async
(
t
,
**
tokenization_kwargs
)
for
t
in
input_texts
)
*
(
encode_async
(
t
,
**
tokenization_kwargs
)
for
t
in
input_texts
)
)
)
engine_prompts
:
list
[
TokensPrompt
]
=
[]
engine_prompts
:
list
[
ProcessorInputs
]
=
[]
for
tok_result
,
input_text
in
zip
(
tokenized_prompts
,
input_texts
):
for
tok_result
,
input_text
in
zip
(
tokenized_prompts
,
input_texts
):
text_token_prompt
=
self
.
_validate_input
(
request
,
tok_result
,
input_text
)
text_token_prompt
=
self
.
_validate_input
(
request
,
tok_result
,
input_text
)
engine_prompts
.
append
(
engine_prompts
.
append
(
TokensPrompt
(
prompt_token_ids
=
text_token_prompt
[
"prompt_token_ids"
])
token_inputs
(
text_token_prompt
[
"prompt_token_ids"
],
prompt
=
input_text
,
)
)
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
...
@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
...
@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id_item
,
request_id_item
,
input_texts
[
i
]
,
engine_prompt
,
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
...
vllm/entrypoints/serve/disagg/serving.py
View file @
574fe752
...
@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
...
@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse
,
GenerateResponse
,
GenerateResponseChoice
,
GenerateResponseChoice
,
)
)
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
...
@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
...
@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id
,
request_id
,
TokensPrompt
(
prompt_token_ids
=
request
.
token_ids
)
,
engine_prompt
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
...
@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
)
tok_params
=
request
.
build_tok_params
(
self
.
model_config
)
tokenization_kwargs
=
tok_params
.
get_encode_kwargs
()
engine_request
=
self
.
input_processor
.
process_inputs
(
request_id
,
engine_prompt
,
sampling_params
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
result_generator
=
self
.
engine_client
.
generate
(
result_generator
=
self
.
engine_client
.
generate
(
engine_
reques
t
,
engine_
promp
t
,
sampling_params
,
sampling_params
,
request_id
,
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
except
ValueError
as
e
:
except
ValueError
as
e
:
...
...
vllm/entrypoints/serve/tokenize/serving.py
View file @
574fe752
...
@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
...
@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse
,
TokenizeResponse
,
TokenizerInfoResponse
,
TokenizerInfoResponse
,
)
)
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
,
token_inputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
...
@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id
,
request_id
,
TokensPrompt
(
prompt_
token_i
ds
=
request
.
tokens
),
token_i
nputs
(
request
.
tokens
),
params
=
None
,
params
=
None
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
...
vllm/inputs/data.py
View file @
574fe752
...
@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
...
@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
Additional options available to all input types.
Additional options available to all input types.
"""
"""
arrival_time
:
NotRequired
[
float
]
"""The time when the input was received (before rendering)."""
cache_salt
:
NotRequired
[
str
]
cache_salt
:
NotRequired
[
str
]
"""Optional cache salt to be used for prefix caching."""
"""Optional cache salt to be used for prefix caching."""
...
@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
...
@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
decoder_prompt
:
DecoderInputs
decoder_prompt
:
DecoderInputs
"""The inputs for the decoder portion."""
"""The inputs for the decoder portion."""
arrival_time
:
NotRequired
[
float
]
"""The time when the input was received (before rendering)."""
ProcessorInputs
:
TypeAlias
=
DecoderOnlyInputs
|
EncoderDecoderInputs
ProcessorInputs
:
TypeAlias
=
DecoderOnlyInputs
|
EncoderDecoderInputs
"""
"""
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment