Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc7f22a8
Commit
cc7f22a8
authored
Jun 11, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-ori
parents
b9ea0c09
b6553be1
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
310 additions
and
110 deletions
+310
-110
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+8
-82
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-0
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+1
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+1
-0
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+1
-0
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+2
-1
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+1
-0
vllm/engine/protocol.py
vllm/engine/protocol.py
+15
-8
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+3
-2
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+4
-3
vllm/entrypoints/cli/benchmark/base.py
vllm/entrypoints/cli/benchmark/base.py
+1
-0
vllm/entrypoints/cli/benchmark/latency.py
vllm/entrypoints/cli/benchmark/latency.py
+1
-0
vllm/entrypoints/cli/benchmark/main.py
vllm/entrypoints/cli/benchmark/main.py
+1
-0
vllm/entrypoints/cli/benchmark/serve.py
vllm/entrypoints/cli/benchmark/serve.py
+1
-0
vllm/entrypoints/cli/benchmark/throughput.py
vllm/entrypoints/cli/benchmark/throughput.py
+1
-0
vllm/entrypoints/cli/collect_env.py
vllm/entrypoints/cli/collect_env.py
+1
-0
vllm/entrypoints/cli/main.py
vllm/entrypoints/cli/main.py
+5
-2
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+1
-0
vllm/entrypoints/cli/run_batch.py
vllm/entrypoints/cli/run_batch.py
+62
-0
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+199
-12
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
vllm/engine/multiprocessing/client.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
Optional
,
Union
,
cast
,
overload
)
Optional
,
Union
,
cast
)
import
cloudpickle
import
psutil
import
zmq
import
zmq.asyncio
from
typing_extensions
import
deprecated
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
...
...
@@ -48,7 +48,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
Device
,
deprecate_kwargs
from
vllm.utils
import
Device
logger
=
init_logger
(
__name__
)
...
...
@@ -441,7 +441,6 @@ class MQLLMEngineClient(EngineClient):
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
@
overload
def
generate
(
self
,
prompt
:
PromptType
,
...
...
@@ -451,39 +450,6 @@ class MQLLMEngineClient(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
overload
@
deprecated
(
"'inputs' will be renamed to 'prompt"
)
def
generate
(
self
,
*
,
inputs
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
generate
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -505,16 +471,12 @@ class MQLLMEngineClient(EngineClient):
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
sampling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
return
cast
(
AsyncGenerator
[
RequestOutput
,
None
],
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
))
@
overload
def
encode
(
self
,
prompt
:
PromptType
,
...
...
@@ -523,37 +485,6 @@ class MQLLMEngineClient(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
...
@
overload
@
deprecated
(
"'inputs' will be renamed to 'prompt"
)
def
encode
(
self
,
*
,
inputs
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
encode
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
"""Generate outputs for a request from a pooling model.
...
...
@@ -574,11 +505,6 @@ class MQLLMEngineClient(EngineClient):
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
cast
(
AsyncGenerator
[
PoolingRequestOutput
,
None
],
self
.
_process_request
(
prompt
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pickle
import
signal
...
...
vllm/engine/output_processor/interfaces.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
List
...
...
vllm/engine/output_processor/multi_step.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
typing
import
Callable
,
List
,
cast
...
...
vllm/engine/output_processor/single_step.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
...
...
vllm/engine/output_processor/stop_checker.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
,
List
,
Optional
,
Tuple
...
...
@@ -81,7 +82,7 @@ class StopChecker:
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
if
seq
.
get_len
()
>
=
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
...
...
vllm/engine/output_processor/util.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
...
...
vllm/engine/protocol.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
from
abc
import
ABC
,
abstractmethod
...
...
@@ -65,6 +66,7 @@ class EngineClient(ABC):
prompt
:
PromptType
,
request_id
:
str
,
params
:
BeamSearchParams
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
...
...
@@ -106,27 +108,31 @@ class EngineClient(ABC):
cum_logprob
=
0
,
logprobs
=
[],
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
)
mm_processor_kwargs
=
mm_processor_kwargs
,
lora_request
=
lora_request
)
]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
prompts_batch
,
lora_req_batch
=
zip
(
*
[(
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
,
multi_modal_data
=
beam
.
multi_modal_data
,
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
)
for
beam
in
all_beams
]
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
)
,
beam
.
lora_request
,
)
for
beam
in
all_beams
])
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
for
i
,
(
individual_prompt
,
lora_req
)
in
enumerate
(
zip
(
prompts_batch
,
lora_req_batch
)):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
,
lora_request
=
lora_req
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
...
...
@@ -159,6 +165,7 @@ class EngineClient(ABC):
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
lora_request
=
current_beam
.
lora_request
,
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
...
...
vllm/entrypoints/api_server.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
...
...
@@ -16,6 +17,7 @@ from typing import Any, Optional
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
import
vllm.envs
as
envs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.launcher
import
serve_http
...
...
@@ -28,7 +30,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
TIMEOUT_KEEP_ALIVE
=
5
# seconds.
app
=
FastAPI
()
engine
=
None
...
...
@@ -133,7 +134,7 @@ async def run_server(args: Namespace,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
timeout_keep_alive
=
envs
.
VLLM_HTTP_
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
...
...
vllm/entrypoints/chat_utils.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
...
...
@@ -1252,7 +1253,7 @@ def apply_hf_chat_template(
# investigation.
logger
.
exception
(
"An error occurred in `transformers` while applying chat template"
)
raise
ValueError
from
e
raise
ValueError
(
str
(
e
))
from
e
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
...
...
@@ -1281,7 +1282,7 @@ def apply_mistral_chat_template(
# We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step
except
(
AssertionError
,
MistralCommonException
)
as
e
:
raise
ValueError
from
e
raise
ValueError
(
str
(
e
))
from
e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
...
...
@@ -1292,7 +1293,7 @@ def apply_mistral_chat_template(
logger
.
exception
(
"An error occurred in `mistral_common` while applying chat "
"template"
)
raise
ValueError
from
e
raise
ValueError
(
str
(
e
))
from
e
def
random_tool_call_id
()
->
str
:
return
f
"chatcmpl-tool-
{
random_uuid
()
}
"
vllm/entrypoints/cli/benchmark/base.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.entrypoints.cli.types
import
CLISubcommand
...
...
vllm/entrypoints/cli/benchmark/latency.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.benchmarks.latency
import
add_cli_args
,
main
...
...
vllm/entrypoints/cli/benchmark/main.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
vllm.entrypoints.cli.benchmark.latency
...
...
vllm/entrypoints/cli/benchmark/serve.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.benchmarks.serve
import
add_cli_args
,
main
...
...
vllm/entrypoints/cli/benchmark/throughput.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.benchmarks.throughput
import
add_cli_args
,
main
...
...
vllm/entrypoints/cli/collect_env.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
...
...
vllm/entrypoints/cli/main.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# The CLI entrypoint to vLLM.
import
signal
...
...
@@ -7,9 +8,10 @@ import sys
import
vllm.entrypoints.cli.benchmark.main
import
vllm.entrypoints.cli.collect_env
import
vllm.entrypoints.cli.openai
import
vllm.entrypoints.cli.run_batch
import
vllm.entrypoints.cli.serve
import
vllm.version
from
vllm.entrypoints.utils
import
VLLM_S
ERVE
_PARSER_EPILOG
,
cli_env_setup
from
vllm.entrypoints.utils
import
VLLM_S
UBCMD
_PARSER_EPILOG
,
cli_env_setup
from
vllm.utils
import
FlexibleArgumentParser
CMD_MODULES
=
[
...
...
@@ -17,6 +19,7 @@ CMD_MODULES = [
vllm
.
entrypoints
.
cli
.
serve
,
vllm
.
entrypoints
.
cli
.
benchmark
.
main
,
vllm
.
entrypoints
.
cli
.
collect_env
,
vllm
.
entrypoints
.
cli
.
run_batch
,
]
...
...
@@ -34,7 +37,7 @@ def main():
parser
=
FlexibleArgumentParser
(
description
=
"vLLM CLI"
,
epilog
=
VLLM_S
ERVE
_PARSER_EPILOG
,
epilog
=
VLLM_S
UBCMD
_PARSER_EPILOG
,
)
parser
.
add_argument
(
'-v'
,
'--version'
,
...
...
vllm/entrypoints/cli/openai.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Commands that act as an interactive OpenAI API client
import
argparse
...
...
vllm/entrypoints/cli/run_batch.py
0 → 100644
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
asyncio
from
prometheus_client
import
start_http_server
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.logger
import
logger
from
vllm.entrypoints.openai.run_batch
import
main
as
run_batch_main
from
vllm.entrypoints.openai.run_batch
import
make_arg_parser
from
vllm.entrypoints.utils
import
(
VLLM_SUBCMD_PARSER_EPILOG
,
show_filtered_argument_or_group_from_help
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.version
import
__version__
as
VLLM_VERSION
class
RunBatchSubcommand
(
CLISubcommand
):
"""The `run-batch` subcommand for vLLM CLI."""
def
__init__
(
self
):
self
.
name
=
"run-batch"
super
().
__init__
()
@
staticmethod
def
cmd
(
args
:
argparse
.
Namespace
)
->
None
:
logger
.
info
(
"vLLM batch processing API version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
# Start the Prometheus metrics server.
# LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if
args
.
enable_metrics
:
logger
.
info
(
"Prometheus metrics enabled"
)
start_http_server
(
port
=
args
.
port
,
addr
=
args
.
url
)
else
:
logger
.
info
(
"Prometheus metrics disabled"
)
asyncio
.
run
(
run_batch_main
(
args
))
def
subparser_init
(
self
,
subparsers
:
argparse
.
_SubParsersAction
)
->
FlexibleArgumentParser
:
run_batch_parser
=
subparsers
.
add_parser
(
"run-batch"
,
help
=
"Run batch prompts and write results to file."
,
description
=
(
"Run batch prompts using vLLM's OpenAI-compatible API.
\n
"
"Supports local or HTTP input/output files."
),
usage
=
"vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>"
,
)
run_batch_parser
=
make_arg_parser
(
run_batch_parser
)
show_filtered_argument_or_group_from_help
(
run_batch_parser
,
"run-batch"
)
run_batch_parser
.
epilog
=
VLLM_SUBCMD_PARSER_EPILOG
return
run_batch_parser
def
cmd_init
()
->
list
[
CLISubcommand
]:
return
[
RunBatchSubcommand
()]
vllm/entrypoints/cli/serve.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
os
import
signal
import
sys
import
uvloop
import
zmq
import
vllm.envs
as
envs
from
vllm
import
AsyncEngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
(
run_server
,
run_server_worker
,
setup_server
)
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.entrypoints.utils
import
(
VLLM_S
ERVE
_PARSER_EPILOG
,
from
vllm.entrypoints.utils
import
(
VLLM_S
UBCMD
_PARSER_EPILOG
,
show_filtered_argument_or_group_from_help
)
from
vllm.executor.multiproc_worker_utils
import
_add_prefix
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
,
zmq_socket_ctx
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
CoreEngineActorManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_completion_or_failure
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
...
...
@@ -36,9 +49,12 @@ class ServeSubcommand(CLISubcommand):
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
args
.
model
=
args
.
model_tag
if
args
.
headless
:
if
args
.
headless
or
args
.
api_server_count
<
1
:
run_headless
(
args
)
elif
args
.
api_server_count
>
1
:
run_multi_api_server
(
args
)
else
:
# Single API server (this process).
uvloop
.
run
(
run_server
(
args
))
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
...
...
@@ -69,6 +85,11 @@ class ServeSubcommand(CLISubcommand):
type
=
int
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
'--api-server-count'
,
'-asc'
,
type
=
int
,
default
=
1
,
help
=
'How many API server processes to run.'
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
...
...
@@ -80,8 +101,8 @@ class ServeSubcommand(CLISubcommand):
)
serve_parser
=
make_arg_parser
(
serve_parser
)
show_filtered_argument_or_group_from_help
(
serve_parser
)
serve_parser
.
epilog
=
VLLM_S
ERVE
_PARSER_EPILOG
show_filtered_argument_or_group_from_help
(
serve_parser
,
"serve"
)
serve_parser
.
epilog
=
VLLM_S
UBCMD
_PARSER_EPILOG
return
serve_parser
...
...
@@ -91,23 +112,26 @@ def cmd_init() -> list[CLISubcommand]:
def
run_headless
(
args
:
argparse
.
Namespace
):
if
args
.
api_server_count
>
1
:
raise
ValueError
(
"api_server_count can't be set in headless mode"
)
# Create the EngineConfig.
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
if
not
envs
.
VLLM_USE_V1
:
raise
Runtim
eError
(
"Headless mode is only supported for V1"
)
raise
Valu
eError
(
"Headless mode is only supported for V1"
)
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
input
_address
=
get_tcp_uri
(
host
,
port
)
handshake
_address
=
get_tcp_uri
(
host
,
port
)
if
local_engine_count
<=
0
:
raise
Runtim
eError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
raise
Valu
eError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def
signal_handler
(
signum
,
frame
):
...
...
@@ -119,7 +143,7 @@ def run_headless(args: argparse.Namespace):
logger
.
info
(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s."
,
local_engine_count
,
input
_address
)
"with head node address %s."
,
local_engine_count
,
handshake
_address
)
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
...
...
@@ -129,7 +153,7 @@ def run_headless(args: argparse.Namespace):
local_start_index
=
0
,
vllm_config
=
vllm_config
,
on_head_node
=
False
,
input_address
=
input
_address
,
handshake_address
=
handshake
_address
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
...
...
@@ -139,3 +163,166 @@ def run_headless(args: argparse.Namespace):
finally
:
logger
.
info
(
"Shutting down."
)
engine_manager
.
close
()
def
run_multi_api_server
(
args
:
argparse
.
Namespace
):
assert
not
args
.
headless
num_api_servers
=
args
.
api_server_count
assert
num_api_servers
>
0
if
num_api_servers
>
1
:
setup_multiprocess_prometheus
()
listen_address
,
sock
=
setup_server
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
model_config
=
vllm_config
.
model_config
if
num_api_servers
>
1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"api_server_count > 1 is only supported for V1"
)
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
raise
ValueError
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1"
)
if
model_config
.
is_multimodal_model
and
not
(
model_config
.
disable_mm_preprocessor_cache
):
logger
.
warning
(
"Multi-model preprocessor cache will be disabled for"
" api_server_count > 1"
)
model_config
.
disable_mm_preprocessor_cache
=
True
parallel_config
=
vllm_config
.
parallel_config
assert
parallel_config
.
data_parallel_rank
==
0
dp_size
=
parallel_config
.
data_parallel_size
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
local_only
=
local_engine_count
==
dp_size
# Set up input and output addresses.
input_addresses
=
[
get_engine_client_zmq_addr
(
local_only
,
host
)
for
_
in
range
(
num_api_servers
)
]
output_addresses
=
[
get_engine_client_zmq_addr
(
local_only
,
host
)
for
_
in
range
(
num_api_servers
)
]
addresses
=
EngineZmqAddresses
(
inputs
=
input_addresses
,
outputs
=
output_addresses
,
)
# Set up coordinator for dp > 1.
coordinator
=
None
stats_update_address
=
None
if
dp_size
>
1
:
coordinator
=
DPCoordinator
(
parallel_config
)
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
stats_update_address
=
coordinator
.
get_stats_publish_address
()
logger
.
info
(
"Started DP Coordinator process (PID: %d)"
,
coordinator
.
proc
.
pid
)
if
parallel_config
.
data_parallel_backend
==
"ray"
:
logger
.
info
(
"Starting ray-based data parallel backend"
)
engine_actor_manager
=
CoreEngineActorManager
(
vllm_config
=
vllm_config
,
addresses
=
addresses
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
# Start API servers using the manager
api_server_manager
=
APIServerProcessManager
(
target_server_fn
=
run_api_server_worker_proc
,
listen_address
=
listen_address
,
sock
=
sock
,
args
=
args
,
num_servers
=
num_api_servers
,
input_addresses
=
input_addresses
,
output_addresses
=
output_addresses
,
stats_update_address
=
stats_update_address
)
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
engine_manager
=
engine_actor_manager
,
coordinator
=
coordinator
)
return
handshake_address
=
get_engine_client_zmq_addr
(
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
with
zmq_socket_ctx
(
handshake_address
,
zmq
.
ROUTER
,
bind
=
True
)
as
handshake_socket
:
# Start local engines.
if
not
local_engine_count
:
local_engine_manager
=
None
else
:
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
handshake_address
=
handshake_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
0
,
local_start_index
=
0
)
# Start API servers using the manager
api_server_manager
=
APIServerProcessManager
(
target_server_fn
=
run_api_server_worker_proc
,
listen_address
=
listen_address
,
sock
=
sock
,
args
=
args
,
num_servers
=
num_api_servers
,
input_addresses
=
input_addresses
,
output_addresses
=
output_addresses
,
stats_update_address
=
stats_update_address
)
# Wait for engine handshakes to complete.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
dp_size
)
]
wait_for_engine_startup
(
handshake_socket
,
addresses
,
core_engines
,
parallel_config
,
vllm_config
.
cache_config
,
local_engine_manager
,
coordinator
.
proc
if
coordinator
else
None
,
)
# Wait for API servers
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
engine_manager
=
local_engine_manager
,
coordinator
=
coordinator
)
def
run_api_server_worker_proc
(
listen_address
,
sock
,
args
,
client_config
=
None
,
**
uvicorn_kwargs
)
->
None
:
"""Entrypoint for individual API server worker processes."""
# Add process-specific prefix to stdout and stderr.
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
uvloop
.
run
(
run_server_worker
(
listen_address
,
sock
,
args
,
client_config
,
**
uvicorn_kwargs
))
Prev
1
…
44
45
46
47
48
49
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment