Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec5e299c
Commit
ec5e299c
authored
Feb 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.3' into v0.7.3-dev
parents
47bd229c
ed6e9075
Changes
521
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
862 additions
and
110 deletions
+862
-110
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+6
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+30
-25
vllm/engine/metrics.py
vllm/engine/metrics.py
+1
-1
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+11
-1
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+13
-2
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+13
-2
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+10
-0
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+2
-1
vllm/entrypoints/cli/__init__.py
vllm/entrypoints/cli/__init__.py
+0
-0
vllm/entrypoints/cli/main.py
vllm/entrypoints/cli/main.py
+79
-0
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+172
-0
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+63
-0
vllm/entrypoints/cli/types.py
vllm/entrypoints/cli/types.py
+24
-0
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+6
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+20
-13
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+101
-21
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+166
-1
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
.../openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+36
-24
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+108
-15
No files found.
Too many changes to show.
To preserve performance only
521 of 521+
files are displayed.
Plain diff
Email patch
vllm/engine/async_llm_engine.py
View file @
ec5e299c
...
@@ -1187,6 +1187,12 @@ class AsyncLLMEngine(EngineClient):
...
@@ -1187,6 +1187,12 @@ class AsyncLLMEngine(EngineClient):
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
)
->
None
:
self
.
engine
.
reset_prefix_cache
()
self
.
engine
.
reset_prefix_cache
()
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine
.
sleep
(
level
)
async
def
wake_up
(
self
)
->
None
:
self
.
engine
.
wake_up
()
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
self
.
engine
.
add_lora
(
lora_request
)
self
.
engine
.
add_lora
(
lora_request
)
...
...
vllm/engine/llm_engine.py
View file @
ec5e299c
...
@@ -20,8 +20,7 @@ import vllm.envs as envs
...
@@ -20,8 +20,7 @@ import vllm.envs as envs
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
SchedulerConfig
,
ObservabilityConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
VllmConfig
)
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
from
vllm.core.scheduler
import
ScheduledSequenceGroup
,
SchedulerOutputs
SchedulerOutputs
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
,
Stats
from
vllm.engine.metrics_types
import
StatLoggerBase
,
Stats
from
vllm.engine.output_processor.interfaces
import
(
from
vllm.engine.output_processor.interfaces
import
(
...
@@ -59,7 +58,8 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -59,7 +58,8 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_kwargs
,
resolve_obj_by_qualname
,
weak_bind
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -347,6 +347,11 @@ class LLMEngine:
...
@@ -347,6 +347,11 @@ class LLMEngine:
# Create the scheduler.
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# GPU and CPU blocks, which are profiled in the distributed executor.
if
isinstance
(
self
.
vllm_config
.
scheduler_config
.
scheduler_cls
,
str
):
Scheduler
=
resolve_obj_by_qualname
(
self
.
vllm_config
.
scheduler_config
.
scheduler_cls
)
else
:
Scheduler
=
self
.
vllm_config
.
scheduler_config
.
scheduler_cls
self
.
scheduler
=
[
self
.
scheduler
=
[
Scheduler
(
Scheduler
(
self
.
scheduler_config
,
self
.
cache_config
,
self
.
lora_config
,
self
.
scheduler_config
,
self
.
cache_config
,
self
.
lora_config
,
...
@@ -437,6 +442,7 @@ class LLMEngine:
...
@@ -437,6 +442,7 @@ class LLMEngine:
@
classmethod
@
classmethod
def
_get_executor_cls
(
cls
,
def
_get_executor_cls
(
cls
,
engine_config
:
VllmConfig
)
->
Type
[
ExecutorBase
]:
engine_config
:
VllmConfig
)
->
Type
[
ExecutorBase
]:
# distributed_executor_backend must be set in VllmConfig.__post_init__
distributed_executor_backend
=
(
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
engine_config
.
parallel_config
.
distributed_executor_backend
)
# Initialize the cluster and specify the executor class.
# Initialize the cluster and specify the executor class.
...
@@ -446,30 +452,29 @@ class LLMEngine:
...
@@ -446,30 +452,29 @@ class LLMEngine:
"distributed_executor_backend must be a subclass of "
"distributed_executor_backend must be a subclass of "
f
"ExecutorBase. Got
{
distributed_executor_backend
}
."
)
f
"ExecutorBase. Got
{
distributed_executor_backend
}
."
)
executor_class
=
distributed_executor_backend
executor_class
=
distributed_executor_backend
elif
engine_config
.
parallel_config
.
world_size
>
1
:
elif
distributed_executor_backend
==
"ray"
:
if
distributed_executor_backend
==
"ray"
:
from
vllm.executor.ray_distributed_executor
import
(
from
vllm.executor.ray_distributed_executor
import
(
RayDistributedExecutor
)
RayDistributedExecutor
)
executor_class
=
RayDistributedExecutor
executor_class
=
RayDistributedExecutor
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
from
vllm.executor.mp_distributed_executor
import
(
from
vllm.executor.mp_distributed_executor
import
(
MultiprocessingDistributedExecutor
)
MultiprocessingDistributedExecutor
)
assert
not
envs
.
VLLM_USE_RAY_SPMD_WORKER
,
(
assert
not
envs
.
VLLM_USE_RAY_SPMD_WORKER
,
(
"multiprocessing distributed executor backend does not "
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1"
)
"support VLLM_USE_RAY_SPMD_WORKER=1"
)
executor_class
=
MultiprocessingDistributedExecutor
executor_class
=
MultiprocessingDistributedExecutor
elif
distributed_executor_backend
==
"uni"
:
elif
distributed_executor_backend
==
"uni"
:
# JAX-style, single-process, multi-device executor.
# JAX-style, single-process, multi-device executor.
from
vllm.executor.uniproc_executor
import
UniProcExecutor
executor_class
=
UniProcExecutor
elif
distributed_executor_backend
==
"external_launcher"
:
# executor with external launcher
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
)
executor_class
=
ExecutorWithExternalLauncher
else
:
from
vllm.executor.uniproc_executor
import
UniProcExecutor
from
vllm.executor.uniproc_executor
import
UniProcExecutor
executor_class
=
UniProcExecutor
executor_class
=
UniProcExecutor
elif
distributed_executor_backend
==
"external_launcher"
:
# executor with external launcher
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
)
executor_class
=
ExecutorWithExternalLauncher
else
:
raise
ValueError
(
"unrecognized distributed_executor_backend: "
f
"
{
distributed_executor_backend
}
"
)
return
executor_class
return
executor_class
@
classmethod
@
classmethod
...
...
vllm/engine/metrics.py
View file @
ec5e299c
...
@@ -237,7 +237,7 @@ class Metrics:
...
@@ -237,7 +237,7 @@ class Metrics:
documentation
=
"Count of successfully processed requests."
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculatie decoding stats
# Speculati
v
e decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
documentation
=
"Speulative token acceptance rate."
,
...
...
vllm/engine/multiprocessing/__init__.py
View file @
ec5e299c
...
@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum):
...
@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum):
RESET_PREFIX_CACHE
=
1
RESET_PREFIX_CACHE
=
1
class
RPCSleepRequest
(
Enum
):
SLEEP_LEVEL_1
=
1
SLEEP_LEVEL_2
=
2
class
RPCWakeUpRequest
(
Enum
):
WAKE_UP
=
1
@
dataclass
@
dataclass
class
RPCLoadAdapterRequest
:
class
RPCLoadAdapterRequest
:
lora_request
:
LoRARequest
lora_request
:
LoRARequest
...
@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse:
...
@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCStartupRequest
,
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCStartupRequest
,
RPCUProfileRequest
,
RPCLoadAdapterRequest
,
RPCUProfileRequest
,
RPCLoadAdapterRequest
,
RPCResetPrefixCacheRequest
]
RPCResetPrefixCacheRequest
,
RPCSleepRequest
,
RPCWakeUpRequest
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCAdapterLoadedResponse
,
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCAdapterLoadedResponse
,
RPCError
]
RPCError
]
...
...
vllm/engine/multiprocessing/client.py
View file @
ec5e299c
...
@@ -31,8 +31,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -31,8 +31,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCSleepRequest
,
RPCStartupRequest
,
RPCUProfileRequest
)
RPCStartupResponse
,
RPCUProfileRequest
,
RPCWakeUpRequest
)
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
...
@@ -685,6 +686,16 @@ class MQLLMEngineClient(EngineClient):
...
@@ -685,6 +686,16 @@ class MQLLMEngineClient(EngineClient):
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
socket
=
self
.
input_socket
)
socket
=
self
.
input_socket
)
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
"""Sleep the engine for a given level"""
return
await
self
.
_send_one_way_rpc_request
(
request
=
RPCSleepRequest
(
level
),
socket
=
self
.
input_socket
)
async
def
wake_up
(
self
)
->
None
:
"""Wake up the engine"""
return
await
self
.
_send_one_way_rpc_request
(
request
=
RPCWakeUpRequest
.
WAKE_UP
,
socket
=
self
.
input_socket
)
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
# Uses the same I/O as generate requests
...
...
vllm/engine/multiprocessing/engine.py
View file @
ec5e299c
...
@@ -20,8 +20,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -20,8 +20,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest
,
RPCLoadAdapterRequest
,
RPCProcessRequest
,
RPCProcessRequest
,
RPCResetPrefixCacheRequest
,
RPCResetPrefixCacheRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCSleepRequest
,
RPCStartupRequest
,
RPCUProfileRequest
)
RPCStartupResponse
,
RPCUProfileRequest
,
RPCWakeUpRequest
)
# yapf: enable
# yapf: enable
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
...
@@ -242,6 +243,10 @@ class MQLLMEngine:
...
@@ -242,6 +243,10 @@ class MQLLMEngine:
self
.
_handle_load_adapter_request
(
request
)
self
.
_handle_load_adapter_request
(
request
)
elif
isinstance
(
request
,
RPCResetPrefixCacheRequest
):
elif
isinstance
(
request
,
RPCResetPrefixCacheRequest
):
self
.
reset_prefix_cache
()
self
.
reset_prefix_cache
()
elif
isinstance
(
request
,
RPCSleepRequest
):
self
.
sleep
(
request
.
value
)
elif
isinstance
(
request
,
RPCWakeUpRequest
):
self
.
wake_up
()
else
:
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
f
"
{
type
(
request
)
}
"
)
...
@@ -369,6 +374,12 @@ class MQLLMEngine:
...
@@ -369,6 +374,12 @@ class MQLLMEngine:
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
engine
.
reset_prefix_cache
()
return
self
.
engine
.
reset_prefix_cache
()
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine
.
sleep
(
level
)
def
wake_up
(
self
)
->
None
:
self
.
engine
.
wake_up
()
def
signal_handler
(
*
_
)
->
None
:
def
signal_handler
(
*
_
)
->
None
:
raise
KeyboardInterrupt
(
"MQLLMEngine terminated"
)
raise
KeyboardInterrupt
(
"MQLLMEngine terminated"
)
...
...
vllm/engine/output_processor/stop_checker.py
View file @
ec5e299c
...
@@ -113,7 +113,7 @@ class StopChecker:
...
@@ -113,7 +113,7 @@ class StopChecker:
stop_string_len
=
len
(
stop_str
)
stop_string_len
=
len
(
stop_str
)
# Avoid searching already-searched text.
# Avoid searching already-searched text.
stop_index
=
output_text
.
find
(
stop_str
,
stop_index
=
output_text
.
find
(
stop_str
,
-
new_char_count
-
stop_string_len
)
1
-
new_char_count
-
stop_string_len
)
if
stop_index
==
-
1
:
if
stop_index
==
-
1
:
continue
continue
...
...
vllm/engine/protocol.py
View file @
ec5e299c
...
@@ -278,6 +278,16 @@ class EngineClient(ABC):
...
@@ -278,6 +278,16 @@ class EngineClient(ABC):
"""Reset the prefix cache"""
"""Reset the prefix cache"""
...
...
@
abstractmethod
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
"""Sleep the engine"""
...
@
abstractmethod
async
def
wake_up
(
self
)
->
None
:
"""Wake up the engine"""
...
@
abstractmethod
@
abstractmethod
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
"""Load a new LoRA adapter into the engine for future requests."""
...
...
vllm/entrypoints/api_server.py
View file @
ec5e299c
...
@@ -127,6 +127,7 @@ async def run_server(args: Namespace,
...
@@ -127,6 +127,7 @@ async def run_server(args: Namespace,
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
sock
=
None
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
log_level
=
args
.
log_level
,
...
@@ -144,7 +145,7 @@ async def run_server(args: Namespace,
...
@@ -144,7 +145,7 @@ async def run_server(args: Namespace,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
ge
=
1024
,
le
=
65535
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-ca-certs"
,
parser
.
add_argument
(
"--ssl-ca-certs"
,
...
...
vllm/entrypoints/cli/__init__.py
0 → 100644
View file @
ec5e299c
vllm/entrypoints/cli/main.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# The CLI entrypoint to vLLM.
import
os
import
signal
import
sys
import
vllm.entrypoints.cli.openai
import
vllm.entrypoints.cli.serve
import
vllm.version
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
# Module-level logger for the CLI entrypoint.
logger = init_logger(__name__)

# Modules that contribute subcommands; `main()` calls `cmd_init()` on each.
CMD_MODULES = [vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve]
def register_signal_handlers():
    """Install handlers that exit the process with status 0.

    The handler is registered for ``SIGINT`` and, where the platform defines
    it, ``SIGTSTP``. ``SIGTSTP`` only exists on Unix, so it is looked up
    defensively instead of raising ``AttributeError`` elsewhere.
    """

    def signal_handler(sig, frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    # SIGTSTP (terminal stop) is not available on every platform.
    sigtstp = getattr(signal, "SIGTSTP", None)
    if sigtstp is not None:
        signal.signal(sigtstp, signal_handler)
def env_setup():
    """Default ``VLLM_WORKER_MULTIPROC_METHOD`` to ``spawn`` if it is unset.

    An existing value in the environment is left untouched.
    """
    # The safest multiprocessing method is `spawn`, as the default `fork` method
    # is not compatible with some accelerators. The default method will be
    # changing in future versions of Python, so we should use it explicitly when
    # possible.
    #
    # We only set it here in the CLI entrypoint, because changing to `spawn`
    # could break some existing code using vLLM as a library. `spawn` will cause
    # unexpected behavior if the code is not protected by
    # `if __name__ == "__main__":`.
    #
    # References:
    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
    env_key = "VLLM_WORKER_MULTIPROC_METHOD"
    if env_key in os.environ:
        # Respect an explicit choice made by the user.
        return

    logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
    os.environ[env_key] = "spawn"
def main():
    """Build the top-level parser, then validate and run the chosen command.

    Every module in ``CMD_MODULES`` contributes subcommands through its
    ``cmd_init()``. When no subcommand is selected the help text is printed.
    """
    env_setup()

    parser = FlexibleArgumentParser(description="vLLM CLI")
    parser.add_argument('-v',
                        '--version',
                        action='version',
                        version=vllm.version.__version__)
    subparsers = parser.add_subparsers(required=False, dest="subparser")

    # Map of subcommand name -> handler object, used for validation below.
    registry = {}
    for module in CMD_MODULES:
        for command in module.cmd_init():
            command_parser = command.subparser_init(subparsers)
            # Remember which callable to run when this subcommand is chosen.
            command_parser.set_defaults(dispatch_function=command.cmd)
            registry[command.name] = command

    args = parser.parse_args()
    if args.subparser in registry:
        registry[args.subparser].validate(args)

    if not hasattr(args, "dispatch_function"):
        # No subcommand was given on the command line.
        parser.print_help()
        return
    args.dispatch_function(args)
# Run the CLI only when this module is executed as a script.
if __name__ == "__main__":
    main()
vllm/entrypoints/cli/openai.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Commands that act as an interactive OpenAI API client
import
argparse
import
os
import
signal
import
sys
from
typing
import
List
,
Optional
,
Tuple
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.utils
import
FlexibleArgumentParser
def _register_signal_handlers():
    """Install handlers that exit the process with status 0.

    The handler is registered for ``SIGINT`` and, where the platform defines
    it, ``SIGTSTP``. ``SIGTSTP`` only exists on Unix, so it is looked up
    defensively instead of raising ``AttributeError`` elsewhere.
    """

    def signal_handler(sig, frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    # SIGTSTP (terminal stop) is not available on every platform.
    sigtstp = getattr(signal, "SIGTSTP", None)
    if sigtstp is not None:
        signal.signal(sigtstp, signal_handler)
def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
    """Prepare an interactive session: signal handlers, client, model name.

    Args:
        args: Parsed options; ``url``, ``api_key`` and ``model_name`` are read.

    Returns:
        A ``(model_name, client)`` tuple. When ``args.model_name`` is falsy,
        the id of the first model listed by the server is used.
    """
    _register_signal_handlers()

    base_url = args.url
    # Precedence: --api-key, then $OPENAI_API_KEY, then the literal "EMPTY".
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
    openai_client = OpenAI(api_key=api_key, base_url=base_url)

    model_name = args.model_name
    if not model_name:
        # Fall back to whatever the server lists first.
        model_name = openai_client.models.list().data[0].id

    print(f"Using model: {model_name}")
    return model_name, openai_client
def chat(system_prompt: Optional[str], model_name: str,
         client: OpenAI) -> None:
    """Run an interactive chat loop until end-of-file is read from stdin.

    Args:
        system_prompt: Optional first message sent with the "system" role.
        model_name: Model passed to every chat completion request.
        client: Client whose ``chat.completions.create`` is called per turn.
    """
    conversation: List[ChatCompletionMessageParam] = (
        [] if system_prompt is None else [{
            "role": "system",
            "content": system_prompt
        }])

    print("Please enter a message for the chat model:")
    while True:
        try:
            user_text = input("> ")
        except EOFError:
            # End of input ends the session quietly.
            return
        conversation.append({"role": "user", "content": user_text})

        completion = client.chat.completions.create(model=model_name,
                                                    messages=conversation)
        reply = completion.choices[0].message
        reply_text = reply.content

        # Keep the reply in the history so the next turn has full context.
        conversation.append(reply)  # type: ignore
        print(reply_text)
def _add_query_options(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Add the ``--url``, ``--model-name`` and ``--api-key`` options.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` object.
    """
    url_help = "url of the running OpenAI-Compatible RESTful API server"
    model_help = ("The model name used in prompt completion, default to "
                  "the first model in list models API call.")
    api_key_help = (
        "API key for OpenAI services. If provided, this api key "
        "will overwrite the api key obtained through environment variables.")

    parser.add_argument("--url",
                        type=str,
                        default="http://localhost:8000/v1",
                        help=url_help)
    parser.add_argument("--model-name",
                        type=str,
                        default=None,
                        help=model_help)
    parser.add_argument("--api-key",
                        type=str,
                        default=None,
                        help=api_key_help)
    return parser
class ChatCommand(CLISubcommand):
    """The `chat` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "chat"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Start an interactive chat session against the running API server.

        Args:
            args: Parsed CLI options; ``system_prompt`` is read here and the
                connection options are consumed by ``_interactive_cli``.
        """
        model_name, client = _interactive_cli(args)
        # Reuse the module-level chat loop instead of keeping a second copy.
        chat(args.system_prompt, model_name, client)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``chat`` subparser and return it.

        Args:
            subparsers: Sub-parser collection of the top-level CLI parser.
        """
        chat_parser = subparsers.add_parser(
            "chat",
            help="Generate chat completions via the running API server",
            usage="vllm chat [options]")
        _add_query_options(chat_parser)
        chat_parser.add_argument(
            "--system-prompt",
            type=str,
            default=None,
            help=("The system prompt to be added to the chat template, "
                  "used for models that support system prompts."))
        return chat_parser
class CompleteCommand(CLISubcommand):
    """The `complete` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "complete"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Read prompts from stdin and print one completion for each.

        The loop ends quietly when end-of-file is read, matching the
        behaviour of the chat loop in this module.

        Args:
            args: Parsed CLI options, consumed by ``_interactive_cli``.
        """
        model_name, client = _interactive_cli(args)
        print("Please enter prompt to complete:")
        while True:
            try:
                input_prompt = input("> ")
            except EOFError:
                # Ctrl-D / closed stdin ends the session without a traceback.
                return
            completion = client.completions.create(model=model_name,
                                                   prompt=input_prompt)
            output = completion.choices[0].text
            print(output)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``complete`` subparser and return it.

        Args:
            subparsers: Sub-parser collection of the top-level CLI parser.
        """
        complete_parser = subparsers.add_parser(
            "complete",
            help=("Generate text completions based on the given prompt "
                  "via the running API server"),
            usage="vllm complete [options]")
        _add_query_options(complete_parser)
        return complete_parser
def cmd_init() -> List[CLISubcommand]:
    """Return one fresh instance of each subcommand defined in this module."""
    command_classes = (ChatCommand, CompleteCommand)
    return [command_cls() for command_cls in command_classes]
vllm/entrypoints/cli/serve.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
argparse
from
typing
import
List
import
uvloop
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.utils
import
FlexibleArgumentParser
class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "serve"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the API server for the model given as a positional argument.

        Args:
            args: Parsed CLI options; ``model`` and ``model_tag`` are read.

        Raises:
            ValueError: If ``args.model`` differs from ``EngineArgs.model``,
                i.e. the model was supplied through ``--model``.
        """
        # The default value of `--model`
        if args.model != EngineArgs.model:
            raise ValueError(
                "With `vllm serve`, you should provide the model as a "
                "positional argument instead of via the `--model` option.")

        # EngineArgs expects the model name to be passed as --model.
        args.model = args.model_tag

        uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate parsed arguments via ``validate_parsed_serve_args``."""
        validate_parsed_serve_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``serve`` subparser and return it.

        Args:
            subparsers: Sub-parser collection of the top-level CLI parser.

        Returns:
            The result of ``make_arg_parser`` applied to the new subparser.
        """
        serve_parser = subparsers.add_parser(
            "serve",
            help="Start the vLLM OpenAI Compatible API server",
            usage="vllm serve <model_tag> [options]")
        serve_parser.add_argument("model_tag",
                                  type=str,
                                  help="The model tag to serve")
        # Adjacent literals are concatenated as-is, so each piece carries its
        # own trailing space to keep the rendered help text readable.
        serve_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config file. "
            "Must be a YAML with the following options: "
            "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
        )

        return make_arg_parser(serve_parser)
def cmd_init() -> List[CLISubcommand]:
    """Return the subcommands defined in this module (only ``serve``)."""
    commands: List[CLISubcommand] = [ServeSubcommand()]
    return commands
vllm/entrypoints/cli/types.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
argparse
from
vllm.utils
import
FlexibleArgumentParser
class CLISubcommand:
    """Base class for CLI argument handlers."""

    # Subcommand name; subclasses in this package assign it in ``__init__``.
    name: str

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Execute the subcommand; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses should implement this method")

    def validate(self, args: argparse.Namespace) -> None:
        """Check parsed arguments before dispatch; no validation by default."""
        return None

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register this subcommand's parser; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses should implement this method")
vllm/entrypoints/launcher.py
View file @
ec5e299c
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
import
asyncio
import
asyncio
import
signal
import
signal
import
socket
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Any
from
typing
import
Any
,
Optional
import
uvicorn
import
uvicorn
from
fastapi
import
FastAPI
,
Request
,
Response
from
fastapi
import
FastAPI
,
Request
,
Response
...
@@ -17,7 +18,8 @@ from vllm.utils import find_process_using_port
...
@@ -17,7 +18,8 @@ from vllm.utils import find_process_using_port
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
async
def
serve_http
(
app
:
FastAPI
,
sock
:
Optional
[
socket
.
socket
],
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
methods
=
getattr
(
route
,
"methods"
,
None
)
...
@@ -34,7 +36,8 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
...
@@ -34,7 +36,8 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
serve
())
server_task
=
loop
.
create_task
(
server
.
serve
(
sockets
=
[
sock
]
if
sock
else
None
))
def
signal_handler
()
->
None
:
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
# prevents the uvicorn signal handler to exit early
...
...
vllm/entrypoints/llm.py
View file @
ec5e299c
...
@@ -421,7 +421,7 @@ class LLM:
...
@@ -421,7 +421,7 @@ class LLM:
instead pass them via the ``inputs`` parameter.
instead pass them via the ``inputs`` parameter.
"""
"""
runner_type
=
self
.
llm_engine
.
model_config
.
runner_type
runner_type
=
self
.
llm_engine
.
model_config
.
runner_type
if
runner_type
!=
"generate"
:
if
runner_type
not
in
[
"generate"
,
"transcription"
]
:
messages
=
[
messages
=
[
"LLM.generate() is only supported for (conditional) generation "
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration)."
,
"models (XForCausalLM, XForConditionalGeneration)."
,
...
@@ -1051,9 +1051,9 @@ class LLM:
...
@@ -1051,9 +1051,9 @@ class LLM:
def
_cross_encoding_score
(
def
_cross_encoding_score
(
self
,
self
,
tokenizer
:
Union
[
AnyTokenizer
]
,
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
text_1
:
List
[
str
],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
text_2
:
List
[
str
],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
...
@@ -1176,29 +1176,36 @@ class LLM:
...
@@ -1176,29 +1176,36 @@ class LLM:
if
isinstance
(
text_1
,
(
str
,
dict
)):
if
isinstance
(
text_1
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
text_1
=
[
text_1
]
text_1
=
[
text_1
]
text_1
=
[
ensure_str
(
t
)
for
t
in
text_1
]
input_text_1
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_1
]
if
isinstance
(
text_2
,
(
str
,
dict
)):
if
isinstance
(
text_2
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
text_2
=
[
text_2
]
text_2
=
[
text_2
]
text_2
=
[
ensure_str
(
t
)
for
t
in
text_2
]
input_text_2
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_2
]
if
len
(
text_1
)
>
1
and
len
(
text_1
)
!=
len
(
text_2
):
if
len
(
input_
text_1
)
>
1
and
len
(
input_
text_1
)
!=
len
(
input_
text_2
):
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
if
len
(
text_1
)
==
0
:
if
len
(
input_
text_1
)
==
0
:
raise
ValueError
(
"At least one text element must be given"
)
raise
ValueError
(
"At least one text element must be given"
)
if
len
(
text_2
)
==
0
:
if
len
(
input_
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
raise
ValueError
(
"At least one text_pair element must be given"
)
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
return
self
.
_cross_encoding_score
(
tokenizer
,
input_text_1
,
input_text_2
,
truncate_prompt_tokens
,
use_tqdm
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
lora_request
,
prompt_adapter_request
)
prompt_adapter_request
)
else
:
else
:
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
return
self
.
_embedding_score
(
lora_request
,
prompt_adapter_request
)
tokenizer
,
input_text_1
,
# type: ignore[arg-type]
input_text_2
,
# type: ignore[arg-type]
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
self
.
llm_engine
.
start_profile
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
ec5e299c
...
@@ -10,17 +10,16 @@ import os
...
@@ -10,17 +10,16 @@ import os
import
re
import
re
import
signal
import
signal
import
socket
import
socket
import
sys
import
tempfile
import
tempfile
import
uuid
import
uuid
from
argparse
import
Namespace
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
functools
import
partial
from
functools
import
partial
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Dict
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Annotated
,
AsyncIterator
,
Dict
,
Optional
,
Set
,
Tuple
,
Union
import
uvloop
import
uvloop
from
fastapi
import
APIRouter
,
FastAPI
,
HTTPException
,
Request
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
@@ -62,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -62,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ScoreRequest
,
ScoreResponse
,
ScoreRequest
,
ScoreResponse
,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
,
TokenizeResponse
,
TranscriptionRequest
,
TranscriptionResponse
,
UnloadLoraAdapterRequest
)
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.reasoning_parsers
import
ReasoningParserManager
from
vllm.entrypoints.openai.reasoning_parsers
import
ReasoningParserManager
# yapf: enable
# yapf: enable
...
@@ -76,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
...
@@ -76,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_tokenization
import
(
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
OpenAIServingTokenization
)
from
vllm.entrypoints.openai.serving_transcription
import
(
OpenAIServingTranscription
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -253,6 +256,16 @@ async def build_async_engine_client_from_engine_args(
...
@@ -253,6 +256,16 @@ async def build_async_engine_client_from_engine_args(
multiprocess
.
mark_process_dead
(
engine_process
.
pid
)
multiprocess
.
mark_process_dead
(
engine_process
.
pid
)
async def validate_json_request(raw_request: Request):
    """Reject requests whose declared media type is not JSON.

    Used as a route dependency (``Depends(validate_json_request)``) so that
    JSON-only endpoints fail early on any other ``Content-Type``.

    Args:
        raw_request: The incoming request; only its headers are inspected.

    Raises:
        HTTPException: With status 415 (Unsupported Media Type) when the
            media type is anything other than ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    # Keep only the media type and drop parameters such as "; charset=utf-8".
    # Optional whitespace is allowed around the ";" separator, so strip it
    # before comparing; otherwise "application/json ;charset=utf-8" would be
    # rejected even though it declares JSON.
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -319,6 +332,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
...
@@ -319,6 +332,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
return
request
.
app
.
state
.
openai_serving_tokenization
return
request
.
app
.
state
.
openai_serving_tokenization
def transcription(request: Request) -> OpenAIServingTranscription:
    """Return the transcription serving handler stored on the app state.

    The stored value may be ``None`` (the route handler checks for this), so
    callers should not assume a handler is always present.
    """
    app_state = request.app.state
    return app_state.openai_serving_transcription
def
engine_client
(
request
:
Request
)
->
EngineClient
:
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
return
request
.
app
.
state
.
engine_client
...
@@ -336,7 +353,7 @@ async def ping(raw_request: Request) -> Response:
...
@@ -336,7 +353,7 @@ async def ping(raw_request: Request) -> Response:
return
await
health
(
raw_request
)
return
await
health
(
raw_request
)
@
router
.
post
(
"/tokenize"
)
@
router
.
post
(
"/tokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
handler
=
tokenization
(
raw_request
)
...
@@ -351,7 +368,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
...
@@ -351,7 +368,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/detokenize"
)
@
router
.
post
(
"/detokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
handler
=
tokenization
(
raw_request
)
...
@@ -380,7 +397,8 @@ async def show_version():
...
@@ -380,7 +397,8 @@ async def show_version():
return
JSONResponse
(
content
=
ver
)
return
JSONResponse
(
content
=
ver
)
@
router
.
post
(
"/v1/chat/completions"
)
@
router
.
post
(
"/v1/chat/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
...
@@ -401,7 +419,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
...
@@ -401,7 +419,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
@
router
.
post
(
"/v1/completions"
)
@
router
.
post
(
"/v1/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
handler
=
completion
(
raw_request
)
handler
=
completion
(
raw_request
)
...
@@ -419,7 +437,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -419,7 +437,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
@
router
.
post
(
"/v1/embeddings"
)
@
router
.
post
(
"/v1/embeddings"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
handler
=
embedding
(
raw_request
)
handler
=
embedding
(
raw_request
)
...
@@ -465,7 +483,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -465,7 +483,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/pooling"
)
@
router
.
post
(
"/pooling"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
handler
=
pooling
(
raw_request
)
handler
=
pooling
(
raw_request
)
...
@@ -483,7 +501,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
...
@@ -483,7 +501,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/score"
)
@
router
.
post
(
"/score"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
handler
=
score
(
raw_request
)
handler
=
score
(
raw_request
)
...
@@ -501,7 +519,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
...
@@ -501,7 +519,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/v1/score"
)
@
router
.
post
(
"/v1/score"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
logger
.
warning
(
logger
.
warning
(
...
@@ -511,7 +529,32 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
...
@@ -511,7 +529,32 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return
await
create_score
(
request
,
raw_request
)
return
await
create_score
(
request
,
raw_request
)
@
router
.
post
(
"/rerank"
)
@router.post("/v1/audio/transcriptions")
@with_cancellation
async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                   Form()],
                                raw_request: Request):
    """Serve an audio transcription request submitted as form data.

    Reads the uploaded audio, hands it to the transcription handler and maps
    the handler's result onto an HTTP response:

    * ``ErrorResponse`` -> JSON body with the error's own status code,
    * ``TranscriptionResponse`` -> JSON body,
    * anything else -> streamed back as ``text/event-stream``.

    When no transcription handler is configured, an error response is
    returned instead.
    """
    serving = transcription(raw_request)
    if serving is None:
        error_factory = base(raw_request)
        return error_factory.create_error_response(
            message="The model does not support Transcriptions API")

    # The whole upload is read before it is passed to the handler.
    audio_bytes = await request.file.read()
    result = await serving.create_transcription(audio_bytes, request,
                                                raw_request)

    if isinstance(result, ErrorResponse):
        payload = result.model_dump()
        return JSONResponse(content=payload, status_code=result.code)

    if isinstance(result, TranscriptionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
@
router
.
post
(
"/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
handler
=
rerank
(
raw_request
)
handler
=
rerank
(
raw_request
)
...
@@ -528,7 +571,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
...
@@ -528,7 +571,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
assert_never
(
generator
)
assert_never
(
generator
)
@
router
.
post
(
"/v1/rerank"
)
@
router
.
post
(
"/v1/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
do_rerank_v1
(
request
:
RerankRequest
,
raw_request
:
Request
):
async
def
do_rerank_v1
(
request
:
RerankRequest
,
raw_request
:
Request
):
logger
.
warning_once
(
logger
.
warning_once
(
...
@@ -539,7 +582,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
...
@@ -539,7 +582,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
return
await
do_rerank
(
request
,
raw_request
)
return
await
do_rerank
(
request
,
raw_request
)
@
router
.
post
(
"/v2/rerank"
)
@
router
.
post
(
"/v2/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
@
with_cancellation
@
with_cancellation
async
def
do_rerank_v2
(
request
:
RerankRequest
,
raw_request
:
Request
):
async
def
do_rerank_v2
(
request
:
RerankRequest
,
raw_request
:
Request
):
return
await
do_rerank
(
request
,
raw_request
)
return
await
do_rerank
(
request
,
raw_request
)
...
@@ -582,8 +625,26 @@ if envs.VLLM_SERVER_DEV_MODE:
...
@@ -582,8 +625,26 @@ if envs.VLLM_SERVER_DEV_MODE:
await
engine_client
(
raw_request
).
reset_prefix_cache
()
await
engine_client
(
raw_request
).
reset_prefix_cache
()
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@router.post("/sleep")
async def sleep(raw_request: Request):
    """Put the engine to sleep at the requested level.

    The level is read from the ``level`` URL query parameter and defaults to
    ``"1"`` when absent.

    Raises:
        HTTPException: With status 400 when ``level`` is not an integer,
            instead of letting the conversion error surface as a 500.
    """
    # The level arrives as a URL query parameter, not in the POST body.
    level = raw_request.query_params.get("level", "1")
    try:
        sleep_level = int(level)
    except ValueError:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Invalid sleep level {level!r}: expected an integer",
        ) from None
    logger.info("sleep the engine with level %s", level)
    await engine_client(raw_request).sleep(sleep_level)
    # FIXME: in v0 with frontend multiprocessing, the sleep command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
    """Ask the engine to wake up and reply with an empty 200 response."""
    logger.info("wake up the engine")
    client = engine_client(raw_request)
    await client.wake_up()
    # FIXME: in v0 with frontend multiprocessing, the wake-up command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@
router
.
post
(
"/invocations"
)
@
router
.
post
(
"/invocations"
,
dependencies
=
[
Depends
(
validate_json_request
)]
)
async
def
invocations
(
raw_request
:
Request
):
async
def
invocations
(
raw_request
:
Request
):
"""
"""
For SageMaker, routes requests to other handlers based on model `task`.
For SageMaker, routes requests to other handlers based on model `task`.
...
@@ -633,7 +694,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -633,7 +694,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"Lora dynamic loading & unloading is enabled in the API server. "
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
@
router
.
post
(
"/v1/load_lora_adapter"
,
dependencies
=
[
Depends
(
validate_json_request
)])
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
models
(
raw_request
)
handler
=
models
(
raw_request
)
...
@@ -644,7 +706,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -644,7 +706,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return
Response
(
status_code
=
200
,
content
=
response
)
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
@
router
.
post
(
"/v1/unload_lora_adapter"
,
dependencies
=
[
Depends
(
validate_json_request
)])
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
models
(
raw_request
)
handler
=
models
(
raw_request
)
...
@@ -753,7 +816,9 @@ async def init_app_state(
...
@@ -753,7 +816,9 @@ async def init_app_state(
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
log_stats
=
not
args
.
disable_log_stats
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
resolved_chat_template
)
if
resolved_chat_template
is
not
None
:
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
resolved_chat_template
)
state
.
openai_serving_models
=
OpenAIServingModels
(
state
.
openai_serving_models
=
OpenAIServingModels
(
engine_client
=
engine_client
,
engine_client
=
engine_client
,
...
@@ -821,6 +886,12 @@ async def init_app_state(
...
@@ -821,6 +886,12 @@ async def init_app_state(
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
chat_template_content_format
=
args
.
chat_template_content_format
,
)
)
state
.
openai_serving_transcription
=
OpenAIServingTranscription
(
engine_client
,
model_config
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
)
if
model_config
.
runner_type
==
"transcription"
else
None
state
.
task
=
model_config
.
task
state
.
task
=
model_config
.
task
...
@@ -831,6 +902,7 @@ def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
...
@@ -831,6 +902,7 @@ def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
sock
=
socket
.
socket
(
family
=
family
,
type
=
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
family
=
family
,
type
=
socket
.
SOCK_STREAM
)
sock
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
sock
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
sock
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEPORT
,
1
)
sock
.
bind
(
addr
)
sock
.
bind
(
addr
)
return
sock
return
sock
...
@@ -878,8 +950,17 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -878,8 +950,17 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config
=
await
engine_client
.
get_model_config
()
model_config
=
await
engine_client
.
get_model_config
()
await
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
await
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
def _listen_addr(a: str) -> str:
    """Format a host string for the startup log message.

    IPv6 addresses are wrapped in square brackets; an empty host is shown as
    ``0.0.0.0``; anything else is returned unchanged.
    """
    if not is_valid_ipv6_address(a):
        return a if a else "0.0.0.0"
    return f"[{a}]"
logger
.
info
(
"Starting vLLM API server on http://%s:%d"
,
_listen_addr
(
sock_addr
[
0
]),
sock_addr
[
1
])
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
sock
=
sock
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
log_level
=
args
.
uvicorn_log_level
,
...
@@ -888,8 +969,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -888,8 +969,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile
=
args
.
ssl_certfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
# Workaround to work on macOS
fd
=
sock
.
fileno
()
if
sys
.
platform
.
startswith
(
"darwin"
)
else
None
,
**
uvicorn_kwargs
,
**
uvicorn_kwargs
,
)
)
...
@@ -901,7 +980,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -901,7 +980,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# NOTE(simon):
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
parser
=
make_arg_parser
(
parser
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
ec5e299c
...
@@ -8,9 +8,10 @@ from argparse import Namespace
...
@@ -8,9 +8,10 @@ from argparse import Namespace
from
typing
import
Any
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Set
,
Union
from
typing
import
Any
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Set
,
Union
import
torch
import
torch
from
fastapi
import
UploadFile
from
pydantic
import
(
BaseModel
,
ConfigDict
,
Field
,
TypeAdapter
,
from
pydantic
import
(
BaseModel
,
ConfigDict
,
Field
,
TypeAdapter
,
ValidationInfo
,
field_validator
,
model_validator
)
ValidationInfo
,
field_validator
,
model_validator
)
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
,
TypeAlias
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -311,6 +312,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -311,6 +312,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
"Additional kwargs to pass to the template renderer. "
description
=
(
"Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."
),
"Will be accessible by the chat template."
),
)
)
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
Field
(
default
=
None
,
description
=
(
"Additional kwargs to pass to the HF processor."
),
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
default
=
None
,
description
=
(
"If specified, the output will follow the JSON schema."
),
description
=
(
"If specified, the output will follow the JSON schema."
),
...
@@ -1426,3 +1431,163 @@ class LoadLoraAdapterRequest(BaseModel):
...
@@ -1426,3 +1431,163 @@ class LoadLoraAdapterRequest(BaseModel):
class
UnloadLoraAdapterRequest
(
BaseModel
):
class
UnloadLoraAdapterRequest
(
BaseModel
):
lora_name
:
str
lora_name
:
str
lora_int_id
:
Optional
[
int
]
=
Field
(
default
=
None
)
lora_int_id
:
Optional
[
int
]
=
Field
(
default
=
None
)
## Protocols for Audio
AudioResponseFormat
:
TypeAlias
=
Literal
[
"json"
,
"text"
,
"srt"
,
"verbose_json"
,
"vtt"
]
class TranscriptionRequest(OpenAIBaseModel):
    # Request payload for the audio transcription endpoint.
    # Fields are ordered as in the official OpenAI API documentation:
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str
    """ID of the model to use.
    """

    language: Optional[str] = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    timestamp_granularities: List[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[])
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build the ``SamplingParams`` used to run this request.

        Args:
            default_max_tokens: Value passed through as ``max_tokens``.
            default_sampling_params: Optional fallback values; only its
                ``"temperature"`` entry is consulted, and only when the
                request's own temperature is ``None``.

        Returns:
            The result of ``SamplingParams.from_optional`` for the resolved
            temperature and ``default_max_tokens``.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        fallback_params = default_sampling_params or {}

        # The request value wins; fall back to the supplied defaults and then
        # to the class-level default only when the request value is None.
        temperature = self.temperature
        if temperature is None:
            temperature = fallback_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return SamplingParams.from_optional(temperature=temperature,
                                            max_tokens=default_max_tokens)
# Transcription response objects
class TranscriptionResponse(OpenAIBaseModel):
    # Plain transcription result: carries only the transcribed text.

    text: str
    """The transcribed text."""
class TranscriptionWord(OpenAIBaseModel):
    # A single word together with its start and end timestamps.

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""
class TranscriptionSegment(OpenAIBaseModel):
    # One segment of a transcription, with timing and per-segment statistics.

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: List[int]
    """Array of token IDs for the text content."""
class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Detailed transcription result: text plus audio metadata and optional
    # per-segment / per-word breakdowns.

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[List[TranscriptionSegment]] = None
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[List[TranscriptionWord]] = None
    """Extracted words and their corresponding timestamps."""
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
View file @
ec5e299c
...
@@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
]):
return
None
return
None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if
self
.
think_start_token_id
in
previous_token_ids
:
if
self
.
think_start_token_id
in
previous_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in previous, </think> in delta,
# <think> in previous, </think> in delta,
...
@@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
return
DeltaMessage
(
reasoning_content
=
delta_text
)
elif
self
.
think_start_token_id
in
delta_token_ids
:
elif
self
.
think_start_token_id
in
delta_token_ids
:
logger
.
info
(
delta_text
)
if
self
.
think_end_token_id
in
delta_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in delta, </think> in delta, extract reasoning content
# <think> in delta, </think> in delta, extract reasoning content
start_index
=
delta_text
.
find
(
self
.
think_start_token
)
start_index
=
delta_text
.
find
(
self
.
think_start_token
)
...
@@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
return
DeltaMessage
(
reasoning_content
=
delta_text
)
else
:
else
:
# No <think> in previous or delta, reasoning content continues.
# No <think> in previous or delta, also need to check for </think>.
return
DeltaMessage
(
content
=
delta_text
)
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index
=
delta_text
.
find
(
self
.
think_end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
return
DeltaMessage
(
content
=
delta_text
)
else
:
# no </think> in previous or delta, reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
def
extract_reasoning_content
(
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
Tuple
[
Optional
[
str
],
Optional
[
str
]]:
)
->
Tuple
[
Optional
[
str
],
Optional
[
str
]]:
# Check if the model output contains the <think> tokens.
# DeepSeek R1 doesn't generate <think> now.
if
(
self
.
think_start_token
not
in
model_output
# Thus we assume the reasoning content is always at the start.
or
self
.
think_end_token
not
in
model_output
):
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
return
None
,
model_output
if
self
.
think_end_token
not
in
model_output
:
return
model_output
,
None
else
:
else
:
# Add a start token if it's missing to keep compatibility.
if
self
.
think_start_token
not
in
model_output
:
model_output
=
f
"
{
self
.
think_start_token
}{
model_output
}
"
# Use a regex to find the reasoning content
# Use a regex to find the reasoning content
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
# Remove the reasoning content from the model output
end_index
=
len
(
# Although deepseek's <think> token is always at the
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
# beginning of the line, we cannot guarantee that the
)
# other models will follow this convention.
final_output
=
model_output
[
end_index
:]
# Therefore, we need to add :start_index.
start_index
=
model_output
.
find
(
self
.
think_start_token
)
if
len
(
final_output
)
==
0
:
if
start_index
!=
-
1
:
return
reasoning_content
,
None
end_index
=
start_index
+
len
(
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
return
reasoning_content
,
final_output
)
model_output
=
model_output
[:
start_index
]
+
\
model_output
[
end_index
:]
if
len
(
model_output
)
==
0
:
return
reasoning_content
,
None
return
reasoning_content
,
model_output
vllm/entrypoints/openai/run_batch.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
tempfile
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
...
@@ -51,6 +52,13 @@ def parse_args():
...
@@ -51,6 +52,13 @@ def parse_args():
help
=
"The path or url to a single output file. Currently supports "
help
=
"The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT."
)
" the file should be available via HTTP PUT."
)
parser
.
add_argument
(
"--output-tmp-dir"
,
type
=
str
,
default
=
None
,
help
=
"The directory to store the output file before uploading it "
"to the output URL."
,
)
parser
.
add_argument
(
"--response-role"
,
parser
.
add_argument
(
"--response-role"
,
type
=
nullable_str
,
type
=
nullable_str
,
default
=
"assistant"
,
default
=
"assistant"
,
...
@@ -134,17 +142,107 @@ async def read_file(path_or_url: str) -> str:
...
@@ -134,17 +142,107 @@ async def read_file(path_or_url: str) -> str:
return
f
.
read
()
return
f
.
read
()
async
def
write_file
(
path_or_url
:
str
,
data
:
str
)
->
None
:
async def write_local_file(output_path: str,
                           batch_outputs: List[BatchRequestOutput]) -> None:
    """
    Dump every batch output to a local file, one JSON document per line.

    output_path: Destination path; the file is opened in "w" mode with
        UTF-8 encoding, so any existing content is replaced.
    batch_outputs: The outputs to serialize. Each one is written as the
        result of its model_dump_json() call followed by a newline.
    """
    # This is blocking I/O inside a coroutine. That is acceptable as long as
    # run_batch runs as a standalone program, because stalling the event loop
    # won't affect performance there.
    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(f"{item.model_dump_json()}\n" for item in batch_outputs)
async def upload_data(output_url: str, data_or_file: str,
                      from_file: bool) -> None:
    """
    Upload data, or the contents of a local file, to a URL via HTTP PUT.

    output_url: The URL to upload to.
    data_or_file: Either the data to upload or the path to the file to upload.
    from_file: If True, data_or_file is the path to the file to upload
        (it is opened in binary mode and sent as the request body).

    The upload is attempted up to max_retries times, sleeping `delay` seconds
    between attempts, and the function returns as soon as one attempt gets a
    200 response. If every attempt fails, an Exception chained to the last
    underlying error is raised.
    """
    # Timeout is a common issue when uploading large files.
    # We make at most max_retries attempts before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5
    # Only this word differs between the two failure messages.
    payload_kind = "file" if from_file else "data"

    async def _put(session, payload) -> None:
        """PUT payload to output_url and raise unless the status is 200."""
        async with session.put(output_url, data=payload) as response:
            if response.status != 200:
                # response.text() is a coroutine: it has to be awaited,
                # otherwise the message shows a coroutine object instead of
                # the response body.
                body = await response.text()
                raise Exception(f"Failed to upload {payload_kind}.\n"
                                f"Status: {response.status}\n"
                                f"Response: {body}")

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            async with aiohttp.ClientSession(
                    timeout=aiohttp.ClientTimeout(total=1000)) as session:
                if from_file:
                    with open(data_or_file, "rb") as file:
                        await _put(session, file)
                else:
                    await _put(session, data_or_file)
            # Success: stop here. Without this the loop would run again and
            # upload the same payload on every remaining attempt.
            return
        except Exception as e:
            if attempt < max_retries:
                logger.error(
                    "Failed to upload data (attempt %s). "
                    "Error message: %s.\nRetrying in %s seconds...",
                    attempt,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                raise Exception(f"Failed to upload data (attempt {attempt}). "
                                f"Error message: {str(e)}.") from e
async
def
write_file
(
path_or_url
:
str
,
batch_outputs
:
List
[
BatchRequestOutput
],
output_tmp_dir
:
str
)
->
None
:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
async
with
aiohttp
.
ClientSession
()
as
session
,
\
if
output_tmp_dir
is
None
:
session
.
put
(
path_or_url
,
data
=
data
.
encode
(
"utf-8"
)):
logger
.
info
(
"Writing outputs to memory buffer"
)
pass
output_buffer
=
StringIO
()
for
o
in
batch_outputs
:
print
(
o
.
model_dump_json
(),
file
=
output_buffer
)
output_buffer
.
seek
(
0
)
logger
.
info
(
"Uploading outputs to %s"
,
path_or_url
)
await
upload_data
(
path_or_url
,
output_buffer
.
read
().
strip
().
encode
(
"utf-8"
),
from_file
=
False
,
)
else
:
# Write responses to a temporary file and then upload it to the URL.
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
encoding
=
"utf-8"
,
dir
=
output_tmp_dir
,
prefix
=
"tmp_batch_output_"
,
suffix
=
".jsonl"
,
)
as
f
:
logger
.
info
(
"Writing outputs to temporary local file %s"
,
f
.
name
)
await
write_local_file
(
f
.
name
,
batch_outputs
)
logger
.
info
(
"Uploading outputs to %s"
,
path_or_url
)
await
upload_data
(
path_or_url
,
f
.
name
,
from_file
=
True
)
else
:
else
:
# We should make this async, but as long as this is always run as a
logger
.
info
(
"Writing outputs to local file %s"
,
path_or_url
)
# standalone program, blocking the event loop won't affect performance
await
write_local_file
(
path_or_url
,
batch_outputs
)
# in this particular case.
with
open
(
path_or_url
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
data
)
def
make_error_request_output
(
request
:
BatchRequestInput
,
def
make_error_request_output
(
request
:
BatchRequestInput
,
...
@@ -317,12 +415,7 @@ async def main(args):
...
@@ -317,12 +415,7 @@ async def main(args):
with
tracker
.
pbar
():
with
tracker
.
pbar
():
responses
=
await
asyncio
.
gather
(
*
response_futures
)
responses
=
await
asyncio
.
gather
(
*
response_futures
)
output_buffer
=
StringIO
()
await
write_file
(
args
.
output_file
,
responses
,
args
.
output_tmp_dir
)
for
response
in
responses
:
print
(
response
.
model_dump_json
(),
file
=
output_buffer
)
output_buffer
.
seek
(
0
)
await
write_file
(
args
.
output_file
,
output_buffer
.
read
().
strip
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment