Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
594
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1076 additions
and
286 deletions
+1076
-286
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+53
-9
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+32
-11
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+2
-2
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+2
-2
vllm/engine/protocol.py
vllm/engine/protocol.py
+10
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+69
-81
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+204
-58
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+194
-59
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+60
-30
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+115
-25
vllm/entrypoints/openai/reasoning_parsers/__init__.py
vllm/entrypoints/openai/reasoning_parsers/__init__.py
+6
-0
vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
...ypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
+158
-0
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
.../openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+133
-0
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+38
-9
No files found.
Too many changes to show.
To preserve performance only
594 of 594+
files are displayed.
Plain diff
Email patch
vllm/engine/multiprocessing/client.py
View file @
afd0da21
...
...
@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
from
vllm.engine.protocol
import
EngineClient
...
...
@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
queue
=
self
.
output_queues
.
get
(
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
exception
)
# Put each output into the appropriate queue.
elif
isinstance
(
request_outputs
,
RPCAdapterLoadedResponse
):
self
.
_add_output
(
request_outputs
)
else
:
# Put each output into the appropriate stream.
for
request_output
in
request_outputs
:
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
self
.
_add_output
(
request_output
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
def
_add_output
(
self
,
request_output
:
Union
[
RequestOutput
,
RPCAdapterLoadedResponse
]):
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
# Start output_loop
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
if
self
.
output_loop
is
None
:
# only generate once to avoid multiple concurrent output_loops
# this will lead to race conditions and wrong orders of tokens
# returned by the engine
# setup will be called multiple times during the startup of
# the engine
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
with
self
.
get_data_socket
()
as
socket
:
# Wait until server is ready.
...
...
@@ -264,6 +279,7 @@ class MQLLMEngineClient(EngineClient):
self
.
tracing_flag
=
response
.
tracing_enabled
# Start health_loop.
if
self
.
health_loop
is
None
:
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
...
...
@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
socket
=
self
.
input_socket
)
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request
=
RPCLoadAdapterRequest
(
lora_request
)
# Create an output queue for this request.
queue
:
asyncio
.
Queue
[
Union
[
None
,
BaseException
]]
=
asyncio
.
Queue
()
self
.
output_queues
[
request
.
request_id
]
=
queue
# Send the request
request_bytes
=
pickle
.
dumps
(
request
)
await
self
.
input_socket
.
send_multipart
((
request_bytes
,
),
copy
=
False
)
# Wait for the response
request_output
=
await
queue
.
get
()
self
.
output_queues
.
pop
(
request
.
request_id
)
# Raise on error, otherwise happily return None
if
isinstance
(
request_output
,
BaseException
):
raise
request_output
vllm/engine/multiprocessing/engine.py
View file @
afd0da21
...
...
@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCAdapterLoadedResponse
,
RPCError
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
...
...
@@ -234,6 +236,10 @@ class MQLLMEngine:
self
.
start_profile
()
else
:
self
.
stop_profile
()
elif
isinstance
(
request
,
RPCLoadAdapterRequest
):
self
.
_handle_load_adapter_request
(
request
)
elif
isinstance
(
request
,
RPCResetPrefixCacheRequest
):
self
.
reset_prefix_cache
()
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
...
...
@@ -284,6 +290,20 @@ class MQLLMEngine:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_handle_load_adapter_request
(
self
,
request
:
RPCLoadAdapterRequest
):
try
:
self
.
engine
.
add_lora
(
request
.
lora_request
)
except
BaseException
as
e
:
# Send back an error if the adapter fails to load
rpc_err
=
RPCError
(
request_id
=
request
.
request_id
,
is_engine_errored
=
False
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
return
# Otherwise, send back the successful load message
self
.
_send_outputs
(
RPCAdapterLoadedResponse
(
request_id
=
request
.
request_id
))
def
_health_check
(
self
):
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
...
...
@@ -296,7 +316,11 @@ class MQLLMEngine:
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if
outputs
:
try
:
from
ray.exceptions
import
RayTaskError
...
...
@@ -335,16 +359,13 @@ class MQLLMEngine:
self
.
_errored_with
=
e
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
self
.
engine
.
start_profile
()
def
stop_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
stop_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop
_pr
o
fi
le"
)
self
.
engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
engine
.
reset
_pr
e
fi
x_cache
(
)
def
signal_handler
(
*
_
)
->
None
:
...
...
vllm/engine/output_processor/multi_step.py
View file @
afd0da21
...
...
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@
staticmethod
@
functools
.
lru_cache
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/
usage
/compatibility_matrix.md
# Reminder: Please update docs/source/
features
/compatibility_matrix.md
# If the feature combo becomes valid
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
...
...
@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
if
sampling_params
.
detokenize
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
...
...
vllm/engine/output_processor/single_step.py
View file @
afd0da21
...
...
@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
output
s
: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only ha
s
1 output."
)
assert
len
(
outputs
)
==
1
,
"Single step should only ha
ve
1 output."
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
...
...
vllm/engine/protocol.py
View file @
afd0da21
...
...
@@ -270,3 +270,13 @@ class EngineClient(ABC):
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
...
@
abstractmethod
async
def
reset_prefix_cache
(
self
)
->
None
:
"""Reset the prefix cache"""
...
@
abstractmethod
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
...
vllm/entrypoints/chat_utils.py
View file @
afd0da21
...
...
@@ -3,10 +3,10 @@ import codecs
import
json
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
functools
import
lru_cache
,
partial
from
functools
import
cache
,
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
import
jinja2.nodes
import
transformers.utils.chat_template_utils
as
hf_chat_utils
...
...
@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
from
openai.types.chat
import
(
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
)
from
openai.types.chat.chat_completion_content_part_input_audio_param
import
(
InputAudio
)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
...
@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
,
async_get_and_parse_video
,
get_and_parse_audio
,
get_and_parse_image
,
get_and_parse_video
)
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_items
:
List
[
_T
]
=
[]
self
.
_items
_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
@
property
def
model_config
(
self
)
->
ModelConfig
:
return
self
.
_model_config
@
property
def
allowed_local_media_path
(
self
):
return
self
.
_model_config
.
allowed_local_media_path
@
staticmethod
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
...
...
@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
if
model_type
in
(
"minicpmo"
,
"minicpmv"
)
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
"pixtral"
):
...
...
@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
if
model_type
in
(
"chameleon"
,
"deepseek_vl_v2"
,
"internvl_chat"
,
"NVLM_D"
,
"h2ovl_chat"
):
return
"<image>"
if
model_type
==
"mllama"
:
return
"<|image|>"
...
...
@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
==
"qwen2_audio"
:
return
(
f
"Audio
{
current_count
}
: "
f
"<|audio_bos|><|AUDIO|><|audio_eos|>"
)
if
model_type
==
"minicpmo"
:
return
"(<audio>./</audio>)"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|video_pad|><|vision_end|>"
if
model_type
in
(
"minicpmo"
,
"minicpmv"
):
return
"(<video>./</video>)"
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
video_token_index
)
...
...
@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
@
staticmethod
def
_combine
(
items
:
List
[
MultiModalDataDict
])
->
MultiModalDataDict
:
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
# Merge all the multi-modal items
for
single_mm_data
in
items
:
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
if
isinstance
(
mm_item
,
list
):
mm_lists
[
mm_key
].
extend
(
mm_item
)
else
:
mm_lists
[
mm_key
].
append
(
mm_item
)
# Unpack any single item lists for models that don't expect multiple.
return
{
mm_key
:
mm_list
[
0
]
if
len
(
mm_list
)
==
1
else
mm_list
for
mm_key
,
mm_list
in
mm_lists
.
items
()
}
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_
consumed_items
.
get
(
modality
,
0
)
+
1
current_count
=
len
(
self
.
_
items_by_modality
[
modality
]
)
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_items
.
append
(
item
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
...
...
@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
MultiModalDataDi
ct
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
obje
ct
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
return
self
.
_combine
(
self
.
_items
)
if
self
.
_items
else
None
if
self
.
_items_by_modality
:
return
dict
(
self
.
_items_by_modality
)
return
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
MultiModalDataDict
]]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
self
.
_items
:
items
=
await
asyncio
.
gather
(
*
self
.
_items
)
return
self
.
_combine
(
items
)
if
self
.
_items_by_modality
:
return
{
modality
:
await
asyncio
.
gather
(
*
items
)
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
return
None
...
...
@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
raise
NotImplementedError
@
abstractmethod
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
]
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
raise
NotImplementedError
@
abstractmethod
...
...
@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image
=
get_and_parse_image
(
image_url
,
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
image
=
self
.
_connector
.
fetch_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
get_and_parse
_audio
(
audio_url
)
audio
=
self
.
_connector
.
fetch
_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio
=
get_and_parse_audio
(
audio_url
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
return
self
.
parse_audio
(
audio_url
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
get_and_parse
_video
(
video_url
)
video
=
self
.
_connector
.
fetch
_video
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
...
...
@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super
().
__init__
()
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image_coro
=
async_get_and_parse_image
(
image_url
,
allowed_local_media_path
=
self
.
_tracker
.
_model_config
.
allowed_local_media_path
)
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
Dict
[
str
,
str
])
->
None
:
input_audio_data
=
input_audio
.
get
(
"data"
,
""
)
input_audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
input_audio_format
}
;base64,
{
input_audio_data
}
"
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
return
self
.
parse_audio
(
audio_url
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
video
=
async_get_and_parse_video
(
video_url
)
video
=
self
.
_connector
.
fetch_video_async
(
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
...
...
@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_VideoParser
=
partial
(
cast
,
ChatCompletionContentPartVideoParam
)
_ContentPart
:
TypeAlias
=
Union
[
str
,
Dict
[
str
,
str
],
InputAudio
]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP
:
Dict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
Union
[
str
,
Dict
[
str
,
str
]]]]
=
{
MM_PARSER_MAP
:
Dict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
...
...
@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
Tuple
[
str
,
Union
[
str
,
Dict
[
str
,
str
]]]:
part
:
ChatCompletionContentPartParam
)
->
tuple
[
str
,
_ContentPart
]:
"""
Parses a given multi-modal content part based on its type.
...
...
@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
*
,
wrap_dicts
:
bool
,
)
->
List
[
ConversationMessage
]:
content
:
List
[
Union
[
str
,
Dict
[
str
,
str
]]]
=
[]
content
=
list
[
_ContentPart
]()
mm_parser
=
mm_tracker
.
create_parser
()
...
...
@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
mm_parser
:
BaseMultiModalContentParser
,
*
,
wrap_dicts
:
bool
,
)
->
Optional
[
Union
[
str
,
Dict
[
str
,
str
]]
]:
)
->
Optional
[
_ContentPart
]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
...
...
@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if
isinstance
(
part
,
str
):
# Handle plain text parts
text
=
_TextParser
(
part
)
return
text
return
part
# Handle structured dictionary parts
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
...
...
@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
if
part_type
==
"input_audio"
:
dict_content
=
cast
(
Dict
[
str
,
str
]
,
content
)
dict_content
=
cast
(
InputAudio
,
content
)
mm_parser
.
parse_input_audio
(
dict_content
)
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
...
...
@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
**
kwargs
:
Any
,
)
->
List
[
int
]:
if
chat_template
is
not
None
:
print_
warning_once
(
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
print_
warning_once
(
logger
.
warning_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
...
...
vllm/entrypoints/llm.py
View file @
afd0da21
import
itertools
import
warnings
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
from
typing
import
(
Any
,
Callable
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
import
cloudpickle
import
torch
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
typing_extensions
import
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm
import
envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
...
@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages
,
resolve_chat_template_content_format
)
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
is_token_prompt
,
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
...
...
@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -186,6 +191,13 @@ class LLM:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
if
"worker_cls"
in
kwargs
:
worker_cls
=
kwargs
[
"worker_cls"
]
# if the worker_cls is not a qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if
isinstance
(
worker_cls
,
type
):
kwargs
[
"worker_cls"
]
=
cloudpickle
.
dumps
(
worker_cls
)
if
compilation_config
is
not
None
:
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
compilation_config_instance
=
CompilationConfig
.
from_cli
(
...
...
@@ -225,18 +237,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self
.
engine_class
=
self
.
get_engine_class
()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self
.
llm_engine
=
self
.
engine_class
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
def
__del__
(
self
):
if
hasattr
(
self
,
'llm_engine'
)
and
self
.
llm_engine
and
hasattr
(
self
.
llm_engine
,
"shutdown"
):
self
.
llm_engine
.
shutdown
()
@
staticmethod
def
get_engine_class
()
->
Type
[
LLMEngine
]:
if
envs
.
VLLM_USE_V1
:
...
...
@@ -462,9 +467,47 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
List
[
_R
]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]
]],
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]],
params
:
BeamSearchParams
,
)
->
List
[
BeamSearchOutput
]:
"""
...
...
@@ -500,8 +543,10 @@ class LLM:
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
prompt_tokens
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
if
is_token_prompt
(
prompt
):
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
...
...
@@ -952,6 +997,107 @@ class LLM:
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_embedding_score
(
self
,
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
encoded_output
=
self
.
encode
(
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
encoded_output_1
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
=
encoded_output
[
len
(
text_1
):]
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
output_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
encoded_output_1
,
encoded_output_2
)]
scores
=
[]
scorer
=
torch
.
nn
.
CosineSimilarity
(
0
)
for
embed_1
,
embed_2
in
output_pairs
:
pair_score
=
scorer
(
embed_1
.
outputs
.
data
,
embed_2
.
outputs
.
data
)
if
(
pad_token_id
:
=
getattr
(
tokenizer
,
"pad_token_id"
,
None
))
is
not
None
:
tokens
=
embed_1
.
prompt_token_ids
+
[
pad_token_id
]
+
embed_2
.
prompt_token_ids
else
:
tokens
=
embed_1
.
prompt_token_ids
+
embed_2
.
prompt_token_ids
scores
.
append
(
PoolingRequestOutput
(
request_id
=
f
"
{
embed_1
.
request_id
}
_
{
embed_2
.
request_id
}
"
,
outputs
=
pair_score
,
prompt_token_ids
=
tokens
,
finished
=
True
))
items
=
self
.
engine_class
.
validate_outputs
(
scores
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
_cross_encoding_score
(
self
,
tokenizer
:
Union
[
AnyTokenizer
],
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
List
[
ScoringRequestOutput
]:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"Score API is only enabled for `--task embed or score`"
)
if
len
(
text_1
)
==
1
:
text_1
=
text_1
*
len
(
text_2
)
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
score
(
self
,
text_1
:
Union
[
SingletonPrompt
,
Sequence
[
SingletonPrompt
]],
...
...
@@ -1003,25 +1149,20 @@ class LLM:
raise
ValueError
(
" "
.
join
(
messages
))
if
not
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
raise
ValueError
(
"Your model does not support cross encoding"
)
if
self
.
llm_engine
.
model_config
.
task
!=
"score"
:
raise
ValueError
(
"Score API is only enabled for `--task score`"
)
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
self
.
llm_engine
.
model_config
.
task
not
in
(
"embed"
,
"score"
):
raise
ValueError
(
"
MistralTokenizer not support
ed
f
or
cross-encoding
"
)
"
Score API is only enabled for `--task emb
ed or
--task score`
"
)
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer
=
self
.
llm_engine
.
get_tokenizer
()
def
ensure_str
(
prompt
:
SingletonPrompt
):
if
isinstance
(
prompt
,
dict
):
if
"multi_modal_data"
in
prompt
:
raise
ValueError
(
"Multi-modal prompt is not "
"supported for
cross en
co
d
ing"
)
"supported for
s
co
r
ing"
)
elif
"prompt_token_ids"
in
prompt
:
prompt
=
tokenizer
.
decode
(
cast
(
TokensPrompt
,
prompt
)[
"prompt_token_ids"
])
...
...
@@ -1047,40 +1188,15 @@ class LLM:
if
len
(
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
if
len
(
text_1
)
==
1
:
text_1
=
text_1
*
len
(
text_2
)
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
text_1
,
text_2
)]
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
parsed_prompts
=
[]
for
q
,
t
in
input_pairs
:
prompt_inputs
=
tokenizer
(
text
=
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
items
=
self
.
engine_class
.
validate_outputs
(
outputs
,
PoolingRequestOutput
)
return
[
ScoringRequestOutput
.
from_base
(
item
)
for
item
in
items
]
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
else
:
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
def start_profile(self) -> None:
    """Delegate to ``self.llm_engine.start_profile()``."""
    engine = self.llm_engine
    engine.start_profile()
...
...
@@ -1088,6 +1204,36 @@ class LLM:
def stop_profile(self) -> None:
    """Delegate to ``self.llm_engine.stop_profile()``."""
    engine = self.llm_engine
    engine.stop_profile()
def reset_prefix_cache(self) -> bool:
    """Return the result of ``self.llm_engine.reset_prefix_cache()``."""
    outcome = self.llm_engine.reset_prefix_cache()
    return outcome
def sleep(self, level: int = 1):
    """
    Put the engine into sleep mode.

    While asleep the engine should not process any requests. It is the
    caller's job to guarantee that nothing is being processed between this
    call and the following `wake_up`.

    The prefix cache is reset first, then the engine is told to sleep.

    :param level: The sleep level.

        * Level 1 offloads the model weights and discards the kv cache
          (its content is forgotten). The weights are backed up in CPU
          memory, so make sure there is enough CPU memory to hold them.
          Good for sleeping and waking up to run the same model again.
        * Level 2 discards both the model weights and the kv cache (the
          content of both is forgotten). Good for waking up to run a
          different model or to update the model, where the previous
          weights are not needed. It reduces CPU memory pressure.
    """
    self.reset_prefix_cache()
    engine = self.llm_engine
    engine.sleep(level=level)
def wake_up(self):
    """
    Wake the engine up from sleep mode by calling
    ``self.llm_engine.wake_up()``. See the :meth:`sleep` method for more
    details.
    """
    engine = self.llm_engine
    engine.wake_up()
# LEGACY
def
_convert_v1_inputs
(
self
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
afd0da21
import
asyncio
import
atexit
import
gc
import
importlib
import
inspect
import
multiprocessing
...
...
@@ -7,16 +8,17 @@ import os
import
re
import
signal
import
socket
import
sys
import
tempfile
import
uuid
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
functools
import
partial
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Optional
,
Set
,
Tuple
from
typing
import
AsyncIterator
,
Dict
,
Optional
,
Set
,
Tuple
,
Union
import
uvloop
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi
import
APIRouter
,
FastAPI
,
HTTPException
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
...
@@ -44,22 +46,31 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionResponse
,
DetokenizeRequest
,
DetokenizeResponse
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
ErrorResponse
,
LoadLoraAdapterRequest
,
PoolingChatRequest
,
PoolingCompletionRequest
,
PoolingRequest
,
PoolingResponse
,
RerankRequest
,
RerankResponse
,
ScoreRequest
,
ScoreResponse
,
TokenizeRequest
,
TokenizeResponse
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.reasoning_parsers
import
ReasoningParserManager
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
(
BaseModelPath
,
OpenAIServingModels
)
from
vllm.entrypoints.openai.serving_pooling
import
OpenAIServingPooling
from
vllm.entrypoints.openai.serving_rerank
import
JinaAIServingRerank
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
...
...
@@ -97,6 +108,11 @@ async def lifespan(app: FastAPI):
task
.
add_done_callback
(
_running_tasks
.
remove
)
else
:
task
=
None
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc
.
collect
()
gc
.
freeze
()
try
:
yield
finally
:
...
...
@@ -133,32 +149,21 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
# AsyncLLMEngine.
if
(
MQLLMEngineClient
.
is_unsupported_config
(
engine_args
)
or
envs
.
VLLM_USE_V1
or
disable_frontend_multiprocessing
):
engine_config
=
engine_args
.
create_engine_config
(
UsageContext
.
OPENAI_API_SERVER
)
uses_ray
=
getattr
(
AsyncLLMEngine
.
_get_executor_cls
(
engine_config
),
"uses_ray"
,
False
)
build_engine
=
partial
(
AsyncLLMEngine
.
from_engine_args
,
engine_client
:
Optional
[
EngineClient
]
=
None
try
:
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
engine_config
=
engine_config
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
if
uses_ray
:
# Must run in main thread with ray for its signal handlers to work
engine_client
=
build_engine
()
else
:
engine_client
=
await
asyncio
.
get_running_loop
().
run_in_executor
(
None
,
build_engine
)
yield
engine_client
if
hasattr
(
engine_client
,
"shutdown"
):
finally
:
if
engine_client
and
hasattr
(
engine_client
,
"shutdown"
):
engine_client
.
shutdown
()
return
#
Otherwise, use the multiprocessing Async
LLMEngine.
#
MQ
LLMEngine.
else
:
if
"PROMETHEUS_MULTIPROC_DIR"
not
in
os
.
environ
:
# Make TemporaryDirectory for prometheus multiprocessing
...
...
@@ -280,6 +285,10 @@ def base(request: Request) -> OpenAIServing:
return
tokenization
(
request
)
def models(request: Request) -> OpenAIServingModels:
    """Return the ``openai_serving_models`` object stored on the app state."""
    app_state = request.app.state
    return app_state.openai_serving_models
def
chat
(
request
:
Request
)
->
Optional
[
OpenAIServingChat
]:
return
request
.
app
.
state
.
openai_serving_chat
...
...
@@ -300,6 +309,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
return
request
.
app
.
state
.
openai_serving_scores
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
    """Return the ``jinaai_serving_reranking`` object stored on the app state.

    The stored value may be ``None``.
    """
    app_state = request.app.state
    return app_state.jinaai_serving_reranking
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
...
...
@@ -315,6 +328,12 @@ async def health(raw_request: Request) -> Response:
return
Response
(
status_code
=
200
)
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    # Same behaviour as the health endpoint; reuse its handler directly.
    health_response = await health(raw_request)
    return health_response
@
router
.
post
(
"/tokenize"
)
@
with_cancellation
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
...
...
@@ -347,10 +366,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
(
raw_request
:
Request
):
handler
=
base
(
raw_request
)
handler
=
models
(
raw_request
)
models
=
await
handler
.
show_available_models
()
return
JSONResponse
(
content
=
models
.
model_dump
())
models
_
=
await
handler
.
show_available_models
()
return
JSONResponse
(
content
=
models
_
.
model_dump
())
@
router
.
get
(
"/version"
)
...
...
@@ -414,6 +433,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead."
)
res
=
await
fallback_handler
.
create_pooling
(
request
,
raw_request
)
generator
:
Union
[
ErrorResponse
,
EmbeddingResponse
]
if
isinstance
(
res
,
PoolingResponse
):
generator
=
EmbeddingResponse
(
id
=
res
.
id
,
...
...
@@ -488,6 +509,103 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return
await
create_score
(
request
,
raw_request
)
@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
    # Look up the rerank serving object on the app state; it may be None,
    # in which case an error response is produced via the base handler.
    serving = rerank(raw_request)
    if serving is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API")
        return error

    result = await serving.do_rerank(request, raw_request)

    # Error results carry their own status code; successful rerank results
    # are serialized with the default status.
    if isinstance(result, ErrorResponse):
        payload = result.model_dump()
        return JSONResponse(content=payload, status_code=result.code)
    if isinstance(result, RerankResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is unexpected.
    assert_never(result)
@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    # Alias of `/rerank` kept under the `/v1` prefix: warn once that the
    # canonical location is `/rerank`, then delegate to `do_rerank`.
    # The literals below are concatenated, so each one must end (or the next
    # one begin) with a space; the original was missing one before
    # "accordingly", producing "clientaccordingly" in the log.
    logger.warning_once(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)")

    return await do_rerank(request, raw_request)
@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    # Alias of `/rerank` under the `/v2` prefix; delegates unchanged.
    response = await do_rerank(request, raw_request)
    return response
# Dispatch table used by the `/invocations` route: maps a model task name to
# the (request model, handler coroutine) pair that should serve it.
# Within a task, the "messages" entry is selected when the request body
# contains a "messages" key; otherwise the "default" entry is used.
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    # "score" and "rerank" only define a "default" entry (no "messages").
    "score": {
        "default": (RerankRequest, do_rerank)
    },
    "rerank": {
        "default": (RerankRequest, do_rerank)
    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}
if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        logger.info("Resetting prefix cache...")
        client = engine_client(raw_request)
        await client.reset_prefix_cache()
        # The outcome of the reset is not inspected; always reply 200.
        return Response(status_code=200)
@router.post("/invocations")
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.
    """
    # A malformed body is a client error, not a server error.
    # json.JSONDecodeError is a ValueError subclass, so this needs no
    # additional import.
    try:
        body = await raw_request.json()
    except ValueError as e:
        raise HTTPException(status_code=400,
                            detail=f"JSON decode error: {e}") from e

    task = raw_request.app.state.task

    if task not in TASK_HANDLERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    handler_config = TASK_HANDLERS[task]
    # Not every task defines a "messages" variant (e.g. "score" and
    # "rerank" only have "default"). Fall back to "default" instead of
    # raising KeyError when the body happens to contain "messages".
    if "messages" in body and "messages" in handler_config:
        request_model, handler = handler_config["messages"]
    else:
        request_model, handler = handler_config["default"]

    # this is required since we lose the FastAPI automatic casting
    request = request_model.model_validate(body)
    return await handler(request, raw_request)
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
logger
.
warning
(
"Torch Profiler is enabled in the API server. This should ONLY be "
...
...
@@ -516,9 +634,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
raw_request
:
Request
):
for
route
in
[
chat
,
completion
,
embedding
]:
handler
=
route
(
raw_request
)
if
handler
is
not
None
:
handler
=
models
(
raw_request
)
response
=
await
handler
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
...
...
@@ -529,9 +645,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
raw_request
:
Request
):
for
route
in
[
chat
,
completion
,
embedding
]:
handler
=
route
(
raw_request
)
if
handler
is
not
None
:
handler
=
models
(
raw_request
)
response
=
await
handler
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
...
...
@@ -602,7 +716,7 @@ def build_app(args: Namespace) -> FastAPI:
module_path
,
object_name
=
middleware
.
rsplit
(
"."
,
1
)
imported
=
getattr
(
importlib
.
import_module
(
module_path
),
object_name
)
if
inspect
.
isclass
(
imported
):
app
.
add_middleware
(
imported
)
app
.
add_middleware
(
imported
)
# type: ignore[arg-type]
elif
inspect
.
iscoroutinefunction
(
imported
):
app
.
middleware
(
"http"
)(
imported
)
else
:
...
...
@@ -612,7 +726,7 @@ def build_app(args: Namespace) -> FastAPI:
return
app
def
init_app_state
(
async
def
init_app_state
(
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
state
:
State
,
...
...
@@ -639,34 +753,40 @@ def init_app_state(
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
resolved_chat_template
)
state
.
openai_serving_models
=
OpenAIServingModels
(
engine_client
=
engine_client
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
)
await
state
.
openai_serving_models
.
init_static_loras
()
state
.
openai_serving_chat
=
OpenAIServingChat
(
engine_client
,
model_config
,
base
_model
_path
s
,
state
.
openai_serving
_models
,
args
.
response_role
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
,
enable_reasoning
=
args
.
enable_reasoning
,
reasoning_parser
=
args
.
reasoning_parser
,
enable_prompt_tokens_details
=
args
.
enable_prompt_tokens_details
,
)
if
model_config
.
runner_type
==
"generate"
else
None
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
engine_client
,
model_config
,
base_model_paths
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
if
model_config
.
runner_type
==
"generate"
else
None
state
.
openai_serving_pooling
=
OpenAIServingPooling
(
engine_client
,
model_config
,
base
_model
_path
s
,
state
.
openai_serving
_models
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
...
...
@@ -674,7 +794,7 @@ def init_app_state(
state
.
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine_client
,
model_config
,
base
_model
_path
s
,
state
.
openai_serving
_models
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
...
...
@@ -682,18 +802,24 @@ def init_app_state(
state
.
openai_serving_scores
=
OpenAIServingScores
(
engine_client
,
model_config
,
base_model_paths
,
state
.
openai_serving_models
,
request_logger
=
request_logger
)
if
model_config
.
task
==
"score"
else
None
state
.
jinaai_serving_reranking
=
JinaAIServingRerank
(
engine_client
,
model_config
,
state
.
openai_serving_models
,
request_logger
=
request_logger
)
if
model_config
.
task
==
"score"
else
None
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine_client
,
model_config
,
base_model_paths
,
lora_modules
=
args
.
lora_modules
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
)
state
.
task
=
model_config
.
task
def
create_server_socket
(
addr
:
Tuple
[
str
,
int
])
->
socket
.
socket
:
...
...
@@ -715,11 +841,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
valid
e
_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
valid_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
if
args
.
enable_auto_tool_choice
\
and
args
.
tool_call_parser
not
in
valid
e
_tool_parses
:
and
args
.
tool_call_parser
not
in
valid_tool_parses
:
raise
KeyError
(
f
"invalid tool call parser:
{
args
.
tool_call_parser
}
"
f
"(chose from {{
{
','
.
join
(
valide_tool_parses
)
}
}})"
)
f
"(chose from {{
{
','
.
join
(
valid_tool_parses
)
}
}})"
)
valid_reasoning_parses
=
ReasoningParserManager
.
reasoning_parsers
.
keys
()
if
args
.
enable_reasoning
\
and
args
.
reasoning_parser
not
in
valid_reasoning_parses
:
raise
KeyError
(
f
"invalid reasoning parser:
{
args
.
reasoning_parser
}
"
f
"(chose from {{
{
','
.
join
(
valid_reasoning_parses
)
}
}})"
)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
...
...
@@ -741,7 +874,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
app
=
build_app
(
args
)
model_config
=
await
engine_client
.
get_model_config
()
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
await
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
shutdown_task
=
await
serve_http
(
app
,
...
...
@@ -753,6 +886,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
# Workaround to work on macOS
fd
=
sock
.
fileno
()
if
sys
.
platform
.
startswith
(
"darwin"
)
else
None
,
**
uvicorn_kwargs
,
)
...
...
vllm/entrypoints/openai/cli_args.py
View file @
afd0da21
...
...
@@ -12,7 +12,8 @@ from typing import List, Optional, Sequence, Union, get_args
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
validate_chat_template
)
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.reasoning_parsers
import
ReasoningParserManager
from
vllm.entrypoints.openai.serving_models
import
(
LoRAModulePath
,
PromptAdapterPath
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -79,29 +80,29 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser
.
add_argument
(
"--host"
,
type
=
nullable_str
,
default
=
None
,
help
=
"
h
ost name"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"
p
ort number"
)
help
=
"
H
ost name
.
"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"
P
ort number
.
"
)
parser
.
add_argument
(
"--uvicorn-log-level"
,
type
=
str
,
default
=
"info"
,
choices
=
[
'debug'
,
'info'
,
'warning'
,
'error'
,
'critical'
,
'trace'
],
help
=
"
l
og level for uvicorn"
)
help
=
"
L
og level for uvicorn
.
"
)
parser
.
add_argument
(
"--allow-credentials"
,
action
=
"store_true"
,
help
=
"
a
llow credentials"
)
help
=
"
A
llow credentials
.
"
)
parser
.
add_argument
(
"--allowed-origins"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"
a
llowed origins"
)
help
=
"
A
llowed origins
.
"
)
parser
.
add_argument
(
"--allowed-methods"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"
a
llowed methods"
)
help
=
"
A
llowed methods
.
"
)
parser
.
add_argument
(
"--allowed-headers"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"
a
llowed headers"
)
help
=
"
A
llowed headers
.
"
)
parser
.
add_argument
(
"--api-key"
,
type
=
nullable_str
,
default
=
None
,
...
...
@@ -115,10 +116,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action
=
LoRAParserAction
,
help
=
"LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (old format):
``
'name=path'
``
"
"Example (new format): "
"
'
{
\"
name
\"
:
\"
name
\"
,
\"
local_
path
\"
:
\"
path
\"
, "
"
\"
base_model_name
\"
:
\"
id
\"
}
'
"
)
"
``
{
\"
name
\"
:
\"
name
\"
,
\"
path
\"
:
\"
lora_
path
\"
, "
"
\"
base_model_name
\"
:
\"
id
\"
}
``
"
)
parser
.
add_argument
(
"--prompt-adapters"
,
type
=
nullable_str
,
...
...
@@ -132,7 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
None
,
help
=
"The file path to the chat template, "
"or the template in single-line form "
"for the specified model"
)
"for the specified model
.
"
)
parser
.
add_argument
(
'--chat-template-content-format'
,
type
=
str
,
...
...
@@ -141,38 +142,39 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help
=
'The format to render message content within a chat template.'
'
\n\n
'
'* "string" will render the content as a string. '
'Example: "Hello World"
\n
'
'Example:
``
"Hello World"
``
\n
'
'* "openai" will render the content as a list of dictionaries, '
'similar to OpenAI schema. '
'Example: [{"type": "text", "text": "Hello world!"}]'
)
'Example:
``
[{"type": "text", "text": "Hello world!"}]
``
'
)
parser
.
add_argument
(
"--response-role"
,
type
=
nullable_str
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=true`."
)
"`
`
request.add_generation_prompt=true`
`
."
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
nullable_str
,
default
=
None
,
help
=
"The file path to the SSL key file"
)
help
=
"The file path to the SSL key file
.
"
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
nullable_str
,
default
=
None
,
help
=
"The file path to the SSL cert file"
)
help
=
"The file path to the SSL cert file
.
"
)
parser
.
add_argument
(
"--ssl-ca-certs"
,
type
=
nullable_str
,
default
=
None
,
help
=
"The CA certificates file"
)
help
=
"The CA certificates file
.
"
)
parser
.
add_argument
(
"--ssl-cert-reqs"
,
type
=
int
,
default
=
int
(
ssl
.
CERT_NONE
),
help
=
"Whether client certificate is required (see stdlib ssl module's)"
help
=
"Whether client certificate is required (see stdlib ssl module's)
.
"
)
parser
.
add_argument
(
"--root-path"
,
type
=
nullable_str
,
default
=
None
,
help
=
"FastAPI root_path when app is behind a path based routing proxy"
)
help
=
"FastAPI root_path when app is behind a path based routing proxy."
)
parser
.
add_argument
(
"--middleware"
,
type
=
nullable_str
,
...
...
@@ -182,15 +184,15 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"using
``
@app.middleware('http')
``
. "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). "
)
"using
``
app.add_middleware()
``
. "
)
parser
.
add_argument
(
"--return-tokens-as-token-ids"
,
action
=
"store_true"
,
help
=
"When --max-logprobs is specified, represents single tokens
as
"
"strings of the form 'token_id:{token_id}' so that tokens
that
"
"are not JSON-encodable can be identified."
)
help
=
"When
``
--max-logprobs
``
is specified, represents single tokens "
"
as
strings of the form 'token_id:{token_id}' so that tokens "
"
that
are not JSON-encodable can be identified."
)
parser
.
add_argument
(
"--disable-frontend-multiprocessing"
,
action
=
"store_true"
,
...
...
@@ -205,9 +207,25 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--enable-auto-tool-choice"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Enable auto tool choice for supported models. Use "
"``--tool-call-parser`` to specify which parser to use."
)
parser
.
add_argument
(
"--enable-reasoning"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable reasoning_content for the model. "
"If enabled, the model will be able to generate reasoning content."
)
valid_reasoning_parsers
=
ReasoningParserManager
.
reasoning_parsers
.
keys
()
parser
.
add_argument
(
"--reasoning-parser"
,
type
=
str
,
metavar
=
"{"
+
","
.
join
(
valid_reasoning_parsers
)
+
"}"
,
default
=
None
,
help
=
"Enable auto tool choice for supported models. Use --tool-call-parser"
" to specify which parser to use"
)
"Select the reasoning parser depending on the model that you're using."
" This is used to parse the reasoning content into OpenAI API "
"format. Required for ``--enable-reasoning``."
)
valid_tool_parsers
=
ToolParserManager
.
tool_parsers
.
keys
()
parser
.
add_argument
(
...
...
@@ -219,7 +237,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help
=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice."
)
"format. Required for
``
--enable-auto-tool-choice
``
."
)
parser
.
add_argument
(
"--tool-parser-plugin"
,
...
...
@@ -228,7 +246,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help
=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser."
)
"in
``
--tool-call-parser
``
."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
...
...
@@ -243,7 +261,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--disable-fastapi-docs"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint
.
"
)
parser
.
add_argument
(
"--enable-prompt-tokens-details"
,
...
...
@@ -267,6 +285,18 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
# Enable reasoning needs a reasoning parser to be valid
if
args
.
enable_reasoning
and
not
args
.
reasoning_parser
:
raise
TypeError
(
"Error: --enable-reasoning requires "
"--reasoning-parser"
)
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if
args
.
enable_auto_tool_choice
and
args
.
enable_reasoning
:
raise
TypeError
(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time"
)
def
create_parser_for_docs
()
->
FlexibleArgumentParser
:
parser_for_docs
=
FlexibleArgumentParser
(
...
...
vllm/entrypoints/openai/protocol.py
View file @
afd0da21
...
...
@@ -3,10 +3,11 @@
import
re
import
time
from
argparse
import
Namespace
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Set
,
Union
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
pydantic
import
(
BaseModel
,
ConfigDict
,
Field
,
TypeAdapter
,
ValidationInfo
,
field_validator
,
model_validator
)
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
...
...
@@ -42,24 +43,32 @@ class OpenAIBaseModel(BaseModel):
# OpenAI API does allow extra fields
model_config
=
ConfigDict
(
extra
=
"allow"
)
@
model_validator
(
mode
=
"before"
)
# Cache class field names
field_names
:
ClassVar
[
Optional
[
Set
[
str
]]]
=
None
@
model_validator
(
mode
=
"wrap"
)
@
classmethod
def
__log_extra_fields__
(
cls
,
data
):
if
isinstance
(
data
,
dict
):
def
__log_extra_fields__
(
cls
,
data
,
handler
):
result
=
handler
(
data
)
if
not
isinstance
(
data
,
dict
):
return
result
field_names
=
cls
.
field_names
if
field_names
is
None
:
# Get all class field names and their potential aliases
field_names
=
set
()
for
field_name
,
field
in
cls
.
model_fields
.
items
():
field_names
.
add
(
field_name
)
if
hasattr
(
field
,
'alias'
)
and
field
.
alias
:
field_names
.
add
(
field
.
alias
)
if
alias
:
=
getattr
(
field
,
'alias'
,
None
):
field_names
.
add
(
alias
)
cls
.
field_names
=
field_names
# Compare against both field names and aliases
extra_fields
=
data
.
keys
()
-
field_names
if
extra_fields
:
if
any
(
k
not
in
field_names
for
k
in
data
):
logger
.
warning
(
"The following fields were present in the request "
"but ignored: %s"
,
extra_fields
)
return
data
"but ignored: %s"
,
data
.
keys
()
-
field_names
)
return
result
class
ErrorResponse
(
OpenAIBaseModel
):
...
...
@@ -372,13 +381,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
)
->
BeamSearchParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
...
...
@@ -398,11 +411,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
...
...
@@ -732,13 +750,17 @@ class CompletionRequest(OpenAIBaseModel):
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
1.0
)
...
...
@@ -756,11 +778,16 @@ class CompletionRequest(OpenAIBaseModel):
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
...
...
@@ -992,6 +1019,52 @@ class ScoreRequest(OpenAIBaseModel):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
# Request model for reranking `documents` against a `query`.
# NOTE: documented with `#` comments rather than a class docstring so the
# generated schema description of this model is left untouched.
class RerankRequest(OpenAIBaseModel):
    model: str
    query: str
    documents: List[str]
    # A plain constant default; a `default_factory` is only needed for
    # mutable defaults. NOTE(review): presumably 0 means "no limit on the
    # number of returned results" -- confirm against the rerank handler.
    top_n: int = Field(default=0)
    # When set, must be >= 1 (enforced by the `ge=1` constraint).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-rerank-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-rerank-pooling-params

    # doc: begin-rerank-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-rerank-extra-params

    def to_pooling_params(self):
        """Build the `PoolingParams` for this request.

        Only `additional_data` is forwarded; no other field of the request
        is passed to `PoolingParams`.
        """
        return PoolingParams(additional_data=self.additional_data)
# Wrapper model holding the text of one document; it is the type of the
# `document` field of `RerankResult`.
class RerankDocument(BaseModel):
    # The document text.
    text: str
# One scored document; `RerankResponse.results` is a list of these.
class RerankResult(BaseModel):
    # NOTE(review): presumably the position of the document in the
    # request's `documents` list -- confirm against the rerank handler.
    index: int
    # The document this score belongs to.
    document: RerankDocument
    # Score assigned to the document for the query.
    # NOTE(review): range and direction are not visible here -- confirm.
    relevance_score: float
# Usage accounting attached to a `RerankResponse`.
class RerankUsage(BaseModel):
    # NOTE(review): presumably the total number of tokens processed for the
    # request -- confirm against the code that fills this in.
    total_tokens: int
# Response model pairing rerank results with usage information.
class RerankResponse(OpenAIBaseModel):
    # Identifier of this response.
    id: str
    # Name of the model reported in the response.
    model: str
    # Token usage for the request.
    usage: RerankUsage
    # Scored documents, one `RerankResult` each.
    # NOTE(review): ordering is not established here -- confirm.
    results: List[RerankResult]
class
CompletionLogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
...
...
@@ -1130,6 +1203,7 @@ class ExtractedToolCallInformation(BaseModel):
class
ChatMessage
(
OpenAIBaseModel
):
role
:
str
reasoning_content
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
tool_calls
:
List
[
ToolCall
]
=
Field
(
default_factory
=
list
)
...
...
@@ -1171,6 +1245,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class
DeltaMessage
(
OpenAIBaseModel
):
role
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
reasoning_content
:
Optional
[
str
]
=
None
tool_calls
:
List
[
DeltaToolCall
]
=
Field
(
default_factory
=
list
)
...
...
@@ -1211,7 +1286,21 @@ class BatchRequestInput(OpenAIBaseModel):
url
:
str
# The parameters of the request.
body
:
Union
[
ChatCompletionRequest
,
EmbeddingRequest
]
body
:
Union
[
ChatCompletionRequest
,
EmbeddingRequest
,
ScoreRequest
]
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
    """Validate `body` against the request model selected by `url`.

    Args:
        value: The raw, not yet validated `body` payload.
        info: Validation context; `info.data` holds the fields that were
            validated successfully before `body`.

    Returns:
        The validated request object for the endpoint named by `url`, or
        the result of validating against the union of all supported request
        types when `url` is unknown or unavailable.
    """
    # Use url to disambiguate models.
    # `url` is absent from `info.data` when it is missing from the input or
    # failed its own validation. Using `.get` avoids a raw KeyError escaping
    # from this validator, so the problem is reported as a validation error
    # on `url` instead.
    url = info.data.get('url')
    if url == "/v1/chat/completions":
        return ChatCompletionRequest.model_validate(value)
    if url == "/v1/embeddings":
        return TypeAdapter(EmbeddingRequest).validate_python(value)
    if url == "/v1/score":
        return ScoreRequest.model_validate(value)
    # Unknown url: fall back to trying every supported request type.
    return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
                             ScoreRequest]).validate_python(value)
class
BatchResponseData
(
OpenAIBaseModel
):
...
...
@@ -1222,7 +1311,8 @@ class BatchResponseData(OpenAIBaseModel):
request_id
:
str
# The body of the response.
body
:
Optional
[
Union
[
ChatCompletionResponse
,
EmbeddingResponse
]]
=
None
body
:
Optional
[
Union
[
ChatCompletionResponse
,
EmbeddingResponse
,
ScoreResponse
]]
=
None
class
BatchRequestOutput
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/reasoning_parsers/__init__.py
0 → 100644
View file @
afd0da21
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
__all__
=
[
"ReasoningParser"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
]
vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
0 → 100644
View file @
afd0da21
import
os
from
functools
import
cached_property
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
import_from_path
,
is_list_of
logger
=
init_logger
(
__name__
)
class ReasoningParser:
    """
    Abstract reasoning parser class that should not be used directly.
    The provided properties and methods should be used in derived classes.

    It is used to extract reasoning content from the model output.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Store the tokenizer used to look up the vocabulary."""
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        """Vocabulary of the tokenizer, computed once per instance."""
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model
        response available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        Tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.

        Raises:
        NotImplementedError
            Always; derived classes must override this method.
        """
        # The message names this method (the original text referred to a
        # non-existent `extract_reasoning_calls`).
        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_content "
            "has not been implemented!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting reasoning
        content from an incomplete response; for use when streaming.
        It receives the text and token ids generated before this step
        (previous), including this step (current), and only in this step
        (delta).

        Raises:
        NotImplementedError
            Always; derived classes must override this method.
        """
        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_content_streaming "
            "has not been implemented!")
class ReasoningParserManager:
    """Registry mapping names to `ReasoningParser` subclasses."""

    # Shared by the whole process: registered name -> parser class.
    reasoning_parsers: Dict[str, Type] = {}

    @classmethod
    def get_reasoning_parser(cls, name) -> Type:
        """
        Look up a reasoning parser class registered via `register_module`.

        Raises a KeyError if nothing is registered under `name`.
        """
        registry = cls.reasoning_parsers
        if name not in registry:
            raise KeyError(f"reasoning helper: '{name}' not found in "
                           "reasoning_parsers")
        return registry[name]

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """
        Store `module` under one or several names.

        With `module_name=None` the class name is used. Unless `force` is
        true, re-using an already registered name raises a KeyError.
        """
        if not issubclass(module, ReasoningParser):
            raise TypeError("module must be subclass of ReasoningParser, "
                            f"but got {type(module)}")

        # Normalise the requested name(s) into an iterable of names.
        if module_name is None:
            aliases = [module.__name__]
        elif isinstance(module_name, str):
            aliases = [module_name]
        else:
            aliases = module_name

        registry = cls.reasoning_parsers
        for alias in aliases:
            if not force and alias in registry:
                owner = registry[alias]
                raise KeyError(f"{alias} is already registered "
                               f"at {owner.__module__}")
            registry[alias] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register a module under the given name or list of names. Usable as
        a decorator (`module` is None) or as a normal function (`module` is
        given); in both cases the registered class is handed back.
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # Validate `name` up front so a decorator fails at decoration time.
        name_is_valid = (name is None or isinstance(name, str)
                         or is_list_of(name, str))
        if not name_is_valid:
            raise TypeError(
                "name must be None, an instance of str, or a sequence of str, "
                f"but got {type(name)}")

        def _register(target):
            cls._register_module(module=target, module_name=name, force=force)
            return target

        # Decorator form: @x.register_module()
        if module is None:
            return _register
        # Function form: x.register_module(module=SomeClass)
        return _register(module)

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser from the file at
        `plugin_path`. The module name is the file name without extension.
        A failing import is logged and otherwise ignored.
        """
        file_name = os.path.basename(plugin_path)
        module_name, _extension = os.path.splitext(file_name)

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
0 → 100644
View file @
afd0da21
import
re
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers
import
(
ReasoningParser
,
ReasoningParserManager
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Resolve the think start/end tokens in the tokenizer vocabulary.

        Raises:
            ValueError: if no tokenizer was passed.
            RuntimeError: if either think token is missing from the
                tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        # Tokens are escaped so they are always matched literally.
        self.reasoning_regex = re.compile(
            rf"{re.escape(self.think_start_token)}(.*?)"
            rf"{re.escape(self.think_end_token)}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns None when the delta is exactly one think start/end token.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                # reasoning is finished; regular content continues
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            # NOTE: the delta is deliberately not logged here; doing so
            # would write generated text to the log for every such chunk.
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta: this is regular content.
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Split a complete model output into reasoning content and content.

        Returns:
            (reasoning_content, content). `reasoning_content` is the text of
            the first <think>...</think> pair, and `content` is the output
            with that pair removed (None if nothing is left). If there is no
            such pair, returns (None, model_output).
        """
        # A single search covers every "no reasoning" case: a missing start
        # token, a missing end token, and both present but in the wrong
        # order (which previously raised an IndexError).
        match = self.reasoning_regex.search(model_output)
        if match is None:
            return None, model_output

        reasoning_content = match.group(1)

        # Remove the reasoning span from the model output.
        # Although deepseek's <think> token is always at the
        # beginning of the line, we cannot guarantee that the
        # other models will follow this convention, so any text
        # before the span is kept as well.
        remaining_output = (model_output[:match.start()] +
                            model_output[match.end():])

        if len(remaining_output) == 0:
            return reasoning_content, None

        return reasoning_content, remaining_output
vllm/entrypoints/openai/run_batch.py
View file @
afd0da21
...
...
@@ -16,11 +16,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput
,
BatchResponseData
,
ChatCompletionResponse
,
EmbeddingResponse
,
ErrorResponse
)
EmbeddingResponse
,
ErrorResponse
,
ScoreResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.entrypoints.openai.serving_models
import
(
BaseModelPath
,
OpenAIServingModels
)
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
@@ -166,7 +169,8 @@ async def run_request(serving_engine_func: Callable,
tracker
:
BatchProgressTracker
)
->
BatchRequestOutput
:
response
=
await
serving_engine_func
(
request
.
body
)
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
)):
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
,
ScoreResponse
)):
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
...
...
@@ -213,13 +217,18 @@ async def main(args):
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
# Create the openai serving objects.
openai_serving_models
=
OpenAIServingModels
(
engine_client
=
engine
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
lora_modules
=
None
,
prompt_adapters
=
None
,
)
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
base
_model
_path
s
,
openai_serving
_models
,
args
.
response_role
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
,
chat_template
=
None
,
chat_template_content_format
=
"auto"
,
...
...
@@ -228,11 +237,17 @@ async def main(args):
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
base
_model
_path
s
,
openai_serving
_models
,
request_logger
=
request_logger
,
chat_template
=
None
,
chat_template_content_format
=
"auto"
,
)
if
model_config
.
task
==
"embed"
else
None
openai_serving_scores
=
(
OpenAIServingScores
(
engine
,
model_config
,
openai_serving_models
,
request_logger
=
request_logger
,
)
if
model_config
.
task
==
"score"
else
None
)
tracker
=
BatchProgressTracker
()
logger
.
info
(
"Reading batch from %s..."
,
args
.
input_file
)
...
...
@@ -273,14 +288,28 @@ async def main(args):
))
continue
response_futures
.
append
(
run_request
(
handler_fn
,
request
,
tracker
))
tracker
.
submitted
()
elif
request
.
url
==
"/v1/score"
:
handler_fn
=
(
None
if
openai_serving_scores
is
None
else
openai_serving_scores
.
create_score
)
if
handler_fn
is
None
:
response_futures
.
append
(
make_async_error_request_output
(
request
,
error_msg
=
"The model does not support Scores API"
,
))
continue
response_futures
.
append
(
run_request
(
handler_fn
,
request
,
tracker
))
tracker
.
submitted
()
else
:
response_futures
.
append
(
make_async_error_request_output
(
request
,
error_msg
=
"Only /v1/chat/completions and "
"/v1/embeddings are supported in the batch endpoint."
,
error_msg
=
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
"are supported in the batch endpoint."
,
))
with
tracker
.
pbar
():
...
...
Prev
1
…
26
27
28
29
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment