Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1292 additions
and
285 deletions
+1292
-285
vllm/engine/metrics.py
vllm/engine/metrics.py
+31
-16
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+31
-8
vllm/engine/protocol.py
vllm/engine/protocol.py
+84
-0
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+53
-24
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+28
-8
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+46
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+43
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+109
-56
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+11
-0
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+83
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+56
-54
vllm/entrypoints/openai/rpc/__init__.py
vllm/entrypoints/openai/rpc/__init__.py
+42
-0
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+248
-0
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+218
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+84
-45
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+35
-30
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-6
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+26
-15
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+15
-14
vllm/envs.py
vllm/envs.py
+42
-8
No files found.
vllm/engine/metrics.py
View file @
e661d594
...
@@ -355,6 +355,7 @@ class StatLoggerBase(ABC):
...
@@ -355,6 +355,7 @@ class StatLoggerBase(ABC):
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
@
abstractmethod
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
...
@@ -364,6 +365,12 @@ class StatLoggerBase(ABC):
...
@@ -364,6 +365,12 @@ class StatLoggerBase(ABC):
def
log
(
self
,
stats
:
Stats
)
->
None
:
def
log
(
self
,
stats
:
Stats
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
maybe_update_spec_decode_metrics
(
self
,
stats
:
Stats
):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
spec_decode_metrics
=
stats
.
spec_decode_metrics
class
LoggingStatLogger
(
StatLoggerBase
):
class
LoggingStatLogger
(
StatLoggerBase
):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
...
@@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
self
.
local_interval
):
...
@@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase):
stats
.
cpu_cache_usage_sys
*
100
,
stats
.
cpu_cache_usage_sys
*
100
,
)
)
if
self
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
self
.
spec_decode_metrics
))
# Reset tracked stats for next interval.
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
if
stats
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
stats
.
spec_decode_metrics
))
def
_format_spec_decode_metrics_str
(
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
...
@@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
self
.
local_interval
):
...
@@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase):
prompt_throughput
=
prompt_throughput
,
prompt_throughput
=
prompt_throughput
,
generation_throughput
=
generation_throughput
)
generation_throughput
=
generation_throughput
)
# Reset tracked stats for next interval.
if
self
.
spec_decode_metrics
is
not
None
:
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
s
tats
.
spec_decode_metrics
.
draft_acceptance_rate
)
s
elf
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
s
tats
.
spec_decode_metrics
.
system_efficiency
)
s
elf
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
s
tats
.
spec_decode_metrics
.
accepted_tokens
)
s
elf
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
s
tats
.
spec_decode_metrics
.
draft_tokens
)
s
elf
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
stats
.
spec_decode_metrics
.
emitted_tokens
)
self
.
spec_decode_metrics
.
emitted_tokens
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
...
...
vllm/engine/output_processor/single_step.py
View file @
e661d594
...
@@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
outputs
:
SequenceGroupOutput
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
n
==
1
and
not
sampling_params
.
use_beam_search
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
# only have one sequence
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
else
:
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
,
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
)
if
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
return
# Process samples
# Process samples
samples
=
outputs
.
samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
@@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
child_seqs
.
append
((
parent
,
parent
))
child_seqs
.
append
((
parent
,
parent
))
for
seq
,
_
in
child_seqs
:
for
seq
,
_
in
child_seqs
:
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
seq_group
.
sampling_params
)
seq
,
sampling_params
)
else
:
else
:
new_char_count
=
0
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
seq
,
new_char_count
,
new_char_count
,
seq_group
.
sampling_params
,
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
lora_req
=
seq_group
.
lora_request
,
)
)
# Non-beam search case
# Non-beam search case
if
not
seq_group
.
sampling_params
.
use_beam_search
:
if
not
sampling_params
.
use_beam_search
:
# For newly created child sequences, add them to the sequence group
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
# and fork them in block manager if they are not finished.
for
seq
,
parent
in
child_seqs
:
for
seq
,
parent
in
child_seqs
:
...
@@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Select the child sequences to keep in the sequence group.
# Select the child sequences to keep in the sequence group.
selected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
selected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
unselected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
unselected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
beam_width
=
seq_group
.
sampling_params
.
best_of
beam_width
=
sampling_params
.
best_of
length_penalty
=
seq_group
.
sampling_params
.
length_penalty
length_penalty
=
sampling_params
.
length_penalty
# Select the newly finished sequences with the highest scores
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# to replace existing finished sequences.
...
@@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
best_running_seq
=
running_child_seqs
[
0
][
0
]
best_running_seq
=
running_child_seqs
[
0
][
0
]
current_worst_seq
=
all_finished_seqs
[
beam_width
-
1
][
0
]
current_worst_seq
=
all_finished_seqs
[
beam_width
-
1
][
0
]
stop_beam_search
=
self
.
_check_beam_search_early_stopping
(
stop_beam_search
=
self
.
_check_beam_search_early_stopping
(
seq_group
.
sampling_params
.
early_stopping
,
sampling_params
.
early_stopping
,
sampling_params
,
seq_group
.
sampling_params
,
best_running_seq
,
current_worst_seq
)
best_running_seq
,
current_worst_seq
)
if
stop_beam_search
:
if
stop_beam_search
:
# Stop the beam search and remove all the running sequences from
# Stop the beam search and remove all the running sequences from
...
...
vllm/engine/protocol.py
0 → 100644
View file @
e661d594
from
typing
import
(
AsyncIterator
,
List
,
Mapping
,
Optional
,
Protocol
,
runtime_checkable
)
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
@
runtime_checkable
class
AsyncEngineClient
(
Protocol
):
"""Protocol class for Clients to AsyncLLMEngine"""
@
property
def
is_running
(
self
)
->
bool
:
...
@
property
def
is_stopped
(
self
)
->
bool
:
...
@
property
def
errored
(
self
)
->
bool
:
...
async
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncIterator
[
RequestOutput
]:
"""Generates outputs for a request"""
async
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncIterator
[
EmbeddingRequestOutput
]:
"""Generate outputs for a request from an embedding model."""
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
Args:
request_id: The unique id of the request.
"""
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PreTrainedTokenizer
:
"""Get the appropriate Tokenizer for the request"""
async
def
is_tracing_enabled
(
self
)
->
bool
:
pass
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
)
->
None
:
pass
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
vllm/entrypoints/api_server.py
View file @
e661d594
...
@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
...
@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
change `vllm/entrypoints/openai/api_server.py` instead.
"""
"""
import
asyncio
import
json
import
json
import
ssl
import
ssl
from
typing
import
AsyncGenerator
from
argparse
import
Namespace
from
typing
import
Any
,
AsyncGenerator
,
Optional
import
uvicorn
from
fastapi
import
FastAPI
,
Request
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
...
@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
return
JSONResponse
(
ret
)
return
JSONResponse
(
ret
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
global
app
app
.
root_path
=
args
.
root_path
return
app
async
def
init_app
(
args
:
Namespace
,
llm_engine
:
Optional
[
AsyncLLMEngine
]
=
None
,
)
->
FastAPI
:
app
=
build_app
(
args
)
global
engine
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
(
llm_engine
if
llm_engine
is
not
None
else
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
))
return
app
async
def
run_server
(
args
:
Namespace
,
llm_engine
:
Optional
[
AsyncLLMEngine
]
=
None
,
**
uvicorn_kwargs
:
Any
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
app
=
await
init_app
(
args
,
llm_engine
)
shutdown_task
=
await
serve_http
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
**
uvicorn_kwargs
,
)
await
shutdown_task
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
...
@@ -105,25 +154,5 @@ if __name__ == "__main__":
...
@@ -105,25 +154,5 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"debug"
)
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"debug"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
)
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
asyncio
.
run
(
run_server
(
args
))
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"Route: %s, Methods: %s"
,
route
.
path
,
methods
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
vllm/entrypoints/chat_utils.py
View file @
e661d594
import
codecs
import
codecs
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Awaitable
,
Iterable
,
List
,
Optional
,
Union
,
cast
,
final
from
typing
import
(
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
...
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
default_factory
=
list
)
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
...
@@ -100,14 +100,16 @@ def _image_token_str(model_config: ModelConfig,
...
@@ -100,14 +100,16 @@ def _image_token_str(model_config: ModelConfig,
if
model_type
==
"phi3_v"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
# These models do not use image tokens in the prompt
# These models do not use image tokens in the prompt
return
None
return
None
if
model_type
.
startswith
(
"llava"
):
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
==
"chameleon"
:
if
model_type
in
(
"chameleon"
,
"internvl_chat"
)
:
return
"<image>"
return
"<image>"
raise
TypeError
(
"Unknown model type: {model_type}"
)
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
# TODO: Let user specify how to insert image tokens into prompt
# TODO: Let user specify how to insert image tokens into prompt
...
@@ -172,7 +174,7 @@ def _parse_chat_message_content_parts(
...
@@ -172,7 +174,7 @@ def _parse_chat_message_content_parts(
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
parse_chat_message_content
(
def
_
parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
message
:
ChatCompletionMessageParam
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
,
tokenizer
:
PreTrainedTokenizer
,
...
@@ -188,3 +190,21 @@ def parse_chat_message_content(
...
@@ -188,3 +190,21 @@ def parse_chat_message_content(
return
_parse_chat_message_content_parts
(
role
,
content
,
model_config
,
return
_parse_chat_message_content_parts
(
role
,
content
,
model_config
,
tokenizer
)
tokenizer
)
def
parse_chat_messages
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
MultiModalDataDict
]]]:
conversation
:
List
[
ConversationMessage
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
msg
in
messages
:
parse_result
=
_parse_chat_message_content
(
msg
,
model_config
,
tokenizer
)
conversation
.
extend
(
parse_result
.
messages
)
mm_futures
.
extend
(
parse_result
.
mm_futures
)
return
conversation
,
mm_futures
vllm/entrypoints/launcher.py
0 → 100644
View file @
e661d594
import
asyncio
import
signal
from
typing
import
Any
import
uvicorn
from
fastapi
import
FastAPI
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
path
=
getattr
(
route
,
"path"
,
None
)
if
methods
is
None
or
path
is
None
:
continue
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
serve
())
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
server_task
.
cancel
()
async
def
dummy_shutdown
()
->
None
:
pass
loop
.
add_signal_handler
(
signal
.
SIGINT
,
signal_handler
)
loop
.
add_signal_handler
(
signal
.
SIGTERM
,
signal_handler
)
try
:
await
server_task
return
dummy_shutdown
()
except
asyncio
.
CancelledError
:
logger
.
info
(
"Gracefully stopping http server"
)
return
server
.
shutdown
()
vllm/entrypoints/llm.py
View file @
e661d594
...
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
...
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt
)
parse_and_batch_prompt
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
GuidedDecodingRequest
,
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.guided_fields
import
LLMGuidedOptions
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -262,6 +265,8 @@ class LLM:
...
@@ -262,6 +265,8 @@ class LLM:
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -303,6 +308,14 @@ class LLM:
...
@@ -303,6 +308,14 @@ class LLM:
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
raise
ValueError
(
"You can only use one guided decoding but multiple is "
f
"specified:
{
guided_options_request
}
"
)
guided_options_request
=
GuidedDecodingRequest
(
**
guided_options_request
)
if
sampling_params
is
None
:
if
sampling_params
is
None
:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
...
@@ -311,7 +324,8 @@ class LLM:
...
@@ -311,7 +324,8 @@ class LLM:
inputs
=
inputs
,
inputs
=
inputs
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
...
@@ -508,6 +522,7 @@ class LLM:
...
@@ -508,6 +522,7 @@ class LLM:
Sequence
[
PoolingParams
]],
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
...
@@ -523,6 +538,15 @@ class LLM:
...
@@ -523,6 +538,15 @@ class LLM:
raise
ValueError
(
"The lengths of prompts and lora_request "
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
"must be the same."
)
if
isinstance
(
params
,
list
):
params
=
[
self
.
_add_guided_processor
(
param
,
guided_options
)
if
isinstance
(
param
,
SamplingParams
)
else
param
for
param
in
params
]
elif
isinstance
(
params
,
SamplingParams
):
params
=
self
.
_add_guided_processor
(
params
,
guided_options
)
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
for
i
,
request_inputs
in
enumerate
(
inputs
):
self
.
_add_request
(
self
.
_add_request
(
...
@@ -548,6 +572,24 @@ class LLM:
...
@@ -548,6 +572,24 @@ class LLM:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
def
_add_guided_processor
(
self
,
params
:
SamplingParams
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
):
if
guided_options
:
if
guided_options
.
guided_decoding_backend
is
None
:
decoding_config
=
self
.
llm_engine
.
get_decoding_config
()
guided_options
.
guided_decoding_backend
=
(
decoding_config
.
guided_decoding_backend
)
guided_logits_processor
=
get_local_guided_decoding_logits_processor
(
#noqa
guided_options
.
guided_decoding_backend
,
guided_options
,
self
.
get_tokenizer
())
if
guided_logits_processor
:
if
params
.
logits_processors
is
None
:
params
.
logits_processors
=
[]
params
.
logits_processors
.
append
(
guided_logits_processor
)
return
params
def
_run_engine
(
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
...
...
vllm/entrypoints/openai/api_server.py
View file @
e661d594
...
@@ -2,13 +2,13 @@ import asyncio
...
@@ -2,13 +2,13 @@ import asyncio
import
importlib
import
importlib
import
inspect
import
inspect
import
re
import
re
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Optional
,
Set
from
multiprocessing
import
Process
from
typing
import
AsyncIterator
,
Set
import
fastapi
from
fastapi
import
APIRouter
,
FastAPI
,
Request
import
uvicorn
from
fastapi
import
APIRouter
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
@@ -16,8 +16,11 @@ from prometheus_client import make_asgi_app
...
@@ -16,8 +16,11 @@ from prometheus_client import make_asgi_app
from
starlette.routing
import
Mount
from
starlette.routing
import
Mount
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -30,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -30,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest
,
ErrorResponse
,
EmbeddingRequest
,
ErrorResponse
,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
)
TokenizeResponse
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
...
@@ -38,12 +43,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
...
@@ -38,12 +43,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization
)
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
TIMEOUT_KEEP_ALIVE
=
5
# seconds
engine
:
Async
LLM
Engine
async_engine_client
:
AsyncEngine
Client
engine_args
:
AsyncEngineArgs
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_completion
:
OpenAIServingCompletion
...
@@ -55,13 +60,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
...
@@ -55,13 +60,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Report whether the named model resolves to an embedding model.

    A ``ModelConfig`` is built only to read its ``embedding_mode``
    attribute. The model name doubles as the tokenizer name, and the
    remaining arguments are fixed (auto tokenizer mode, seed 0, float16).
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode
@
asynccontextmanager
@
asynccontextmanager
async
def
lifespan
(
app
:
fastapi
.
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
async
def
_force_log
():
async
def
_force_log
():
while
True
:
while
True
:
await
asyncio
.
sleep
(
10
)
await
asyncio
.
sleep
(
10
)
await
engine
.
do_log_stats
()
await
async_engine_client
.
do_log_stats
()
if
not
engine_args
.
disable_log_stats
:
if
not
engine_args
.
disable_log_stats
:
task
=
asyncio
.
create_task
(
_force_log
())
task
=
asyncio
.
create_task
(
_force_log
())
...
@@ -71,10 +85,56 @@ async def lifespan(app: fastapi.FastAPI):
...
@@ -71,10 +85,56 @@ async def lifespan(app: fastapi.FastAPI):
yield
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client for the server and tear it down on exit.

    Sets the module-level ``engine_args`` and ``async_engine_client``
    globals as a side effect.

    Two backends are possible:

    * in-process ``AsyncLLMEngine`` when the model is an embedding model or
      ``args.disable_frontend_multiprocessing`` is set;
    * otherwise an ``AsyncEngineRPCClient`` talking to an RPC server that
      is started in a separate process.

    Args:
        args: Parsed CLI arguments; must provide ``model``,
            ``trust_remote_code`` and ``disable_frontend_multiprocessing``
            plus whatever ``AsyncEngineArgs.from_cli_args`` reads.

    Yields:
        The engine client to serve requests with.
    """
    # Backend itself still global for the silly lil' health handler
    global engine_args, async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    port = get_open_port(envs.VLLM_RPC_PORT)
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER, port))
    rpc_server_process.start()

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    async_engine_client = AsyncEngineRPCClient(port)

    try:
        # setup() sits inside the try so that a failure while waiting for the
        # server or fetching its configs still terminates the child process
        # and closes the client instead of leaking both.
        await async_engine_client.setup()
        yield async_engine_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        async_engine_client.close()

        # Wait for server process to join
        rpc_server_process.join()
router
=
APIRouter
()
router
=
APIRouter
()
def
mount_metrics
(
app
:
fastapi
.
FastAPI
):
def
mount_metrics
(
app
:
FastAPI
):
# Add prometheus asgi middleware to route /metrics requests
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
# Workaround for 307 Redirect for /metrics
...
@@ -85,7 +145,7 @@ def mount_metrics(app: fastapi.FastAPI):
...
@@ -85,7 +145,7 @@ def mount_metrics(app: fastapi.FastAPI):
@
router
.
get
(
"/health"
)
@
router
.
get
(
"/health"
)
async
def
health
()
->
Response
:
async
def
health
()
->
Response
:
"""Health check."""
"""Health check."""
await
openai_serving_chat
.
engine
.
check_health
()
await
async_engine_client
.
check_health
()
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
...
@@ -164,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -164,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
JSONResponse
(
content
=
generator
.
model_dump
())
def
build_app
(
args
)
:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
fastapi
.
FastAPI
(
lifespan
=
lifespan
)
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
...
@@ -213,37 +273,18 @@ def build_app(args):
...
@@ -213,37 +273,18 @@ def build_app(args):
return
app
return
app
def
run_server
(
args
,
llm_engine
=
None
):
async
def
init_app
(
async_engine_client
:
AsyncEngineClient
,
args
:
Namespace
,
)
->
FastAPI
:
app
=
build_app
(
args
)
app
=
build_app
(
args
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
served_model_names
=
args
.
served_model_name
else
:
else
:
served_model_names
=
[
args
.
model
]
served_model_names
=
[
args
.
model
]
global
engine
,
engine_args
model_config
=
await
async_engine_client
.
get_model_config
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
(
llm_engine
if
llm_engine
is
not
None
else
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
))
event_loop
:
Optional
[
asyncio
.
AbstractEventLoop
]
try
:
event_loop
=
asyncio
.
get_running_loop
()
except
RuntimeError
:
event_loop
=
None
if
event_loop
is
not
None
and
event_loop
.
is_running
():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config
=
event_loop
.
run_until_complete
(
engine
.
get_model_config
())
else
:
# When using single vLLM without engine_use_ray
model_config
=
asyncio
.
run
(
engine
.
get_model_config
())
if
args
.
disable_log_requests
:
if
args
.
disable_log_requests
:
request_logger
=
None
request_logger
=
None
...
@@ -256,7 +297,7 @@ def run_server(args, llm_engine=None):
...
@@ -256,7 +297,7 @@ def run_server(args, llm_engine=None):
global
openai_serving_tokenization
global
openai_serving_tokenization
openai_serving_chat
=
OpenAIServingChat
(
openai_serving_chat
=
OpenAIServingChat
(
engine
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
args
.
response_role
,
args
.
response_role
,
...
@@ -264,23 +305,25 @@ def run_server(args, llm_engine=None):
...
@@ -264,23 +305,25 @@ def run_server(args, llm_engine=None):
prompt_adapters
=
args
.
prompt_adapters
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
chat_template
=
args
.
chat_template
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
)
openai_serving_completion
=
OpenAIServingCompletion
(
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
lora_modules
=
args
.
lora_modules
,
...
@@ -289,22 +332,31 @@ def run_server(args, llm_engine=None):
...
@@ -289,22 +332,31 @@ def run_server(args, llm_engine=None):
)
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
return
app
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"
Route: %s, Methods: %s"
,
route
.
path
,
method
s
)
logger
.
info
(
"
args: %s"
,
arg
s
)
uvicorn
.
run
(
app
,
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
host
=
args
.
host
,
app
=
await
init_app
(
async_engine_client
,
args
)
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
shutdown_task
=
await
serve_http
(
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
app
,
ssl_keyfile
=
args
.
ssl_keyfile
,
host
=
args
.
host
,
ssl_certfile
=
args
.
ssl_certfile
,
port
=
args
.
port
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
log_level
=
args
.
uvicorn_log_level
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
**
uvicorn_kwargs
,
)
# NB: Await server shutdown only after the backend context is exited
await
shutdown_task
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -314,4 +366,5 @@ if __name__ == "__main__":
...
@@ -314,4 +366,5 @@ if __name__ == "__main__":
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
run_server
(
args
)
asyncio
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
e661d594
...
@@ -128,6 +128,17 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -128,6 +128,17 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). "
)
"using app.add_middleware(). "
)
parser
.
add_argument
(
"--return-tokens-as-token-ids"
,
action
=
"store_true"
,
help
=
"When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified."
)
parser
.
add_argument
(
"--disable-frontend-multiprocessing"
,
action
=
"store_true"
,
help
=
"If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
...
...
vllm/entrypoints/openai/logits_processors.py
0 → 100644
View file @
e661d594
from
functools
import
lru_cache
,
partial
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Union
import
torch
from
transformers
import
PreTrainedTokenizer
from
vllm.sampling_params
import
LogitsProcessor
class AllowedTokenIdsLogitsProcessor:
    """Restrict sampling to a fixed set of token ids.

    Every logit whose index is not in the allowed set is overwritten with
    ``-inf``. The boolean mask is built on the first call, from the size
    of the last dimension and the device of the logits it is given, and
    is reused for later calls.
    """

    def __init__(self, allowed_ids: Iterable[int]):
        # Kept only until the mask is built; reset to None afterwards.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        # True marks a position that must be suppressed.
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        """Mask ``logits`` in place and return the same tensor.

        ``token_ids`` is accepted for the logits-processor call signature
        but is not read.
        """
        if self.mask is None:
            vocab_dim = logits.shape[-1]
            # Start with everything blocked, then re-open the allowed ids.
            self.mask = torch.ones((vocab_dim, ),
                                   dtype=torch.bool,
                                   device=logits.device)
            self.mask[self.allowed_ids] = False
            self.allowed_ids = None

        neg_inf = float("-inf")
        logits.masked_fill_(self.mask, neg_inf)
        return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Validate the allowed ids and wrap them in a logits processor.

    Results are memoized by ``lru_cache`` (32 entries), so calls with an
    equal ``(allowed_token_ids, vocab_size)`` pair get the same processor
    object back.

    Raises:
        ValueError: If the set is empty, or if any id falls outside
            ``[0, vocab_size)``.
    """
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")

    has_out_of_vocab = any(tid < 0 or tid >= vocab_size
                           for tid in allowed_token_ids)
    if has_out_of_vocab:
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")

    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[int, float],
                                token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    """Add per-token biases to ``logits`` in place.

    Args:
        logit_bias: Mapping from token id to the bias added at that index.
            Keys must be integers because they are used directly as tensor
            indices.
        token_ids: Accepted for the logits-processor call signature; not
            read here.
        logits: Tensor to adjust; it is modified in place.

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits
def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
    """Build the logits processors requested by ``logit_bias`` and
    ``allowed_token_ids``.

    A bias processor is added when ``logit_bias`` is non-empty, and an
    allowed-ids processor is added when ``allowed_token_ids`` is not
    ``None``, in that order.

    Raises:
        ValueError: If a ``logit_bias`` key cannot be converted with
            ``int()``, or if a converted key falls outside
            ``[0, tokenizer.vocab_size)``. Validation errors from the
            allowed-ids helper propagate unchanged.
    """
    processors = []

    if logit_bias:
        try:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            clamped_logit_bias: Dict[int, float] = {}
            for raw_token_id, raw_bias in logit_bias.items():
                clamped_logit_bias[int(raw_token_id)] = min(
                    100.0, max(-100.0, raw_bias))
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check if token_id is within the vocab size
        vocab_size = tokenizer.vocab_size
        if any(not 0 <= token_id < vocab_size
               for token_id in clamped_logit_bias):
            raise ValueError("token_id in logit_bias contains "
                             "out-of-vocab token id")

        bias_processor = partial(logit_bias_logits_processor,
                                 clamped_logit_bias)
        processors.append(bias_processor)

    if allowed_token_ids is not None:
        allowed_processor = _get_allowed_token_ids_logits_processor(
            frozenset(allowed_token_ids), tokenizer.vocab_size)
        processors.append(allowed_processor)

    return processors
vllm/entrypoints/openai/protocol.py
View file @
e661d594
# Adapted from
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
import
time
from
argparse
import
Namespace
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO
=
Namespace
(
min
=-
9223372036854775808
,
max
=
9223372036854775807
)
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
if
isinstance
(
torch
,
_MockModule
):
_LONG_INFO
=
_MOCK_LONG_INFO
else
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
except
ModuleNotFoundError
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
assert
_LONG_INFO
.
min
==
_MOCK_LONG_INFO
.
min
assert
_LONG_INFO
.
max
==
_MOCK_LONG_INFO
.
max
class
OpenAIBaseModel
(
BaseModel
):
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
# OpenAI API does not allow extra fields
...
@@ -106,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -106,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n
:
Optional
[
int
]
=
1
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
Field
(
None
,
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
...
@@ -213,30 +231,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -213,30 +231,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
)
->
SamplingParams
:
def
to_sampling_params
(
# We now allow logprobs being true without top_logrobs.
self
,
tokenizer
:
PreTrainedTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
logits_processors
=
None
# We now allow logprobs being true without top_logrobs.
if
self
.
logit_bias
:
logits_processors
=
get_logits_processors
(
logit_bias
:
Dict
[
int
,
float
]
=
{}
logit_bias
=
self
.
logit_bias
,
try
:
allowed_token_ids
=
None
,
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
tokenizer
=
tokenizer
,
# Convert token_id to integer before we add to LLMEngine
)
# Clamp the bias between -100 and 100 per OpenAI API spec
if
guided_decode_logits_processor
:
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
logits_processors
.
append
(
guided_decode_logits_processor
)
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
return
SamplingParams
(
return
SamplingParams
(
n
=
self
.
n
,
n
=
self
.
n
,
...
@@ -254,7 +264,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -254,7 +264,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
early_stopping
=
self
.
early_stopping
,
...
@@ -333,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -333,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens
:
Optional
[
int
]
=
16
max_tokens
:
Optional
[
int
]
=
16
n
:
int
=
1
n
:
int
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
seed
:
Optional
[
int
]
=
Field
(
None
,
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
...
@@ -358,6 +366,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -358,6 +366,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens
:
bool
=
True
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
# doc: end-completion-sampling-params
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
# doc: begin-completion-extra-params
...
@@ -407,30 +416,23 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -407,30 +416,23 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
# doc: end-completion-extra-params
def
to_sampling_params
(
self
):
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
None
logits_processors
=
get_logits_processors
(
if
self
.
logit_bias
:
logit_bias
=
self
.
logit_bias
,
logit_bias
:
Dict
[
int
,
float
]
=
{}
allowed_token_ids
=
self
.
allowed_token_ids
,
try
:
tokenizer
=
tokenizer
,
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
)
# Convert token_id to integer
if
guided_decode_logits_processor
:
# Clamp the bias between -100 and 100 per OpenAI API spec
logits_processors
.
append
(
guided_decode_logits_processor
)
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
return
SamplingParams
(
return
SamplingParams
(
n
=
self
.
n
,
n
=
self
.
n
,
...
@@ -447,7 +449,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -447,7 +449,7 @@ class CompletionRequest(OpenAIBaseModel):
stop_token_ids
=
self
.
stop_token_ids
,
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
logprobs
,
logprobs
=
self
.
logprobs
,
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
if
not
echo_without_generation
else
1
,
max_tokens
=
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
early_stopping
=
self
.
early_stopping
,
...
...
vllm/entrypoints/openai/rpc/__init__.py
0 → 100644
View file @
e661d594
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Mapping
,
Optional
,
Union
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
# Acknowledgement payload: the RPC client treats a one-way request as
# successful only when the reply equals this string.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# NOTE(review): presumably the reply to a health-check request; its use is
# not visible in this part of the file -- confirm against the client/server.
VLLM_RPC_HEALTHY_STR = "HEALTHY"
@dataclass
class RPCGenerateRequest:
    """Arguments of one generation request, bundled for transport over RPC.

    This is one of the members of ``RPC_REQUEST_TYPE``. The first three
    fields are required; the remaining ones default to ``None``.
    """
    # NOTE(review): the per-field notes below are read off the field names
    # and types only -- confirm against the RPC server's generate handler.
    # Prompt to run generation on.
    inputs: PromptInputs
    # Sampling configuration for this request.
    sampling_params: SamplingParams
    # Identifier of the request.
    request_id: str
    # Optional LoRA adapter to apply.
    lora_request: Optional[LoRARequest] = None
    # Optional tracing headers to forward.
    trace_headers: Optional[Mapping[str, str]] = None
    # Optional prompt adapter to apply.
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
    """RPC message carrying the id of a single request.

    This is one of the members of ``RPC_REQUEST_TYPE``.
    NOTE(review): presumably asks the server to abort that request; the
    handler is not visible here -- confirm.
    """
    request_id: str
class RPCUtilityRequest(Enum):
    """Closed set of utility requests, i.e. requests that carry no payload
    of their own.

    The RPC client pickles a member and sends it as the whole message.
    The numeric values are part of the wire format and must match on both
    ends.
    """
    # Readiness probe sent by the client while waiting for the server.
    IS_SERVER_READY = 1
    # Config queries (names suggest one per config object the server holds).
    GET_MODEL_CONFIG = 2
    GET_DECODING_CONFIG = 3
    GET_PARALLEL_CONFIG = 4
    GET_SCHEDULER_CONFIG = 5
    GET_LORA_CONFIG = 6
    # Action and status requests.
    DO_LOG_STATS = 7
    CHECK_HEALTH = 8
    IS_TRACING_ENABLED = 9
# Union of every message type the RPC client may send to the server.
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                         RPCUtilityRequest]
vllm/entrypoints/openai/rpc/client.py
0 → 100644
View file @
e661d594
from
contextlib
import
contextmanager
from
typing
import
Any
,
AsyncIterator
,
Mapping
,
Optional
import
cloudpickle
import
zmq
import
zmq.asyncio
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
class
AsyncEngineRPCClient
:
def __init__(self, port: int):
    """Create the ZeroMQ asyncio context and record the server address.

    No connection is opened here; sockets are created per request by
    ``socket()``.
    """
    self.context = zmq.asyncio.Context()
    # The server is always addressed on localhost at the given TCP port.
    self.path = f"tcp://localhost:{port}"
async def setup(self):
    """Prepare the client: wait for the server, then cache what it reports.

    After this returns, ``model_config``, ``decoding_config``,
    ``tracing_flag`` and ``tokenizer`` are set on the instance.
    """
    # Wait until server is ready.
    await self.wait_for_server()

    # Get the configs.
    self.model_config = await self._get_model_config_rpc()
    self.decoding_config = await self._get_decoding_config_rpc()
    self.tracing_flag = await self._is_tracing_enabled_rpc()

    # These are needed only to construct the tokenizer group; they are
    # fetched in the order scheduler, parallel, lora.
    scheduler_config = await self._get_scheduler_config_rpc()
    parallel_config = await self._get_parallel_config_rpc()
    lora_enabled = bool(await self._get_lora_config_rpc())

    # Create the tokenizer group.
    # TODO: refactor OAI server to avoid needing this info.
    self.tokenizer = init_tokenizer_from_configs(
        model_config=self.model_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
        enable_lora=lora_enabled,
    )
def close(self):
    """Destroy the ZeroMQ Context."""
    # Called from the server's shutdown path once the client is done.
    self.context.destroy()
@
contextmanager
def
socket
(
self
):
# Ensure client sockets are always closed after use
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
try
:
socket
.
connect
(
self
.
path
)
yield
socket
finally
:
socket
.
close
()
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
expected_type
:
Any
,
error_message
:
str
)
->
Any
:
"""Send an RPC request that is expecting data back."""
with
self
.
socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
not
isinstance
(
data
,
expected_type
):
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
pass
else
:
raise
ValueError
(
error_message
)
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
):
"""Send one-way RPC request to trigger an action."""
with
self
.
socket
()
as
socket
:
# Ping RPC Server with request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
# Await acknowledgement from RPCServer.
response
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
raise
ValueError
(
error_message
)
return
response
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
return
self
.
decoding_config
async
def
get_model_config
(
self
)
->
ModelConfig
:
return
self
.
model_config
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
wait_for_server
(
self
):
"""Wait for the RPCServer to start up."""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server."
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
expected_type
=
ModelConfig
,
error_message
=
"Could not get ModelConfig from RPC Server"
)
async
def
_get_decoding_config_rpc
(
self
)
->
DecodingConfig
:
"""Get DecodingConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
expected_type
=
DecodingConfig
,
error_message
=
"Could not get DecodingConfig from RPC Server"
)
async
def
_get_parallel_config_rpc
(
self
)
->
ParallelConfig
:
"""Get ParallelConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
expected_type
=
ParallelConfig
,
error_message
=
"Could not get ParallelConfig from RPC Server"
)
async
def
_get_scheduler_config_rpc
(
self
)
->
SchedulerConfig
:
"""Get SchedulerConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
expected_type
=
SchedulerConfig
,
error_message
=
"Could not get SchedulerConfig from RPC Server"
)
async
def
_get_lora_config_rpc
(
self
):
"""Get LoRAConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_LORA_CONFIG
,
expected_type
=
LoRAConfig
,
error_message
=
"Could not get LoRAConfig from RPC Server"
)
async
def
_is_tracing_enabled_rpc
(
self
)
->
ParallelConfig
:
"""Get is_tracing_enabled flag from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
expected_type
=
bool
,
error_message
=
"Could not get is_tracing_enabled flag from RPC "
"Server"
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
error_message
=
f
"RPCAbortRequest
{
request_id
}
failed"
)
async
def
do_log_stats
(
self
):
"""Send a DO_LOG_STATS signal to the RPC Server"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
DO_LOG_STATS
,
error_message
=
"RPCRequest DO_LOG_STATS failed."
)
async
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncIterator
[
RequestOutput
]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with
self
.
socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
])
# Stream back the results from the RPC Server.
while
True
:
message
=
await
socket
.
recv
()
request_output
=
cloudpickle
.
loads
(
message
)
if
isinstance
(
request_output
,
Exception
):
raise
request_output
if
request_output
.
finished
:
break
yield
request_output
yield
request_output
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
with
self
.
socket
()
as
socket
:
# Ping RPCServer with CHECK_HEALTH request.
await
socket
.
send
(
cloudpickle
.
dumps
(
RPCUtilityRequest
.
CHECK_HEALTH
)
)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
isinstance
(
health_message
,
Exception
):
raise
health_message
if
health_message
!=
VLLM_RPC_HEALTHY_STR
:
raise
ValueError
(
"Expected healthy response from backend but got "
"f{health_message}"
)
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncIterator
[
EmbeddingRequestOutput
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
vllm/entrypoints/openai/rpc/server.py
0 → 100644
View file @
e661d594
import
asyncio
import
signal
from
typing
import
Any
,
Coroutine
import
cloudpickle
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
class AsyncEngineRPCServer:
    """ZeroMQ ROUTER front-end that exposes an AsyncLLMEngine over RPC.

    Requests arrive as ``[identity, pickled_request]`` frames; every reply is
    sent back as ``[identity, pickled_payload]`` on the same socket.
    """

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, port: int):
        """Build the engine, then bind a ROUTER socket on 127.0.0.1:port."""
        # The engine is created before any ZeroMQ resources.
        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                      usage_context)

        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.constants.ROUTER)

        # Note numeric form of localhost should be used for zmq bind(),
        # see https://stackoverflow.com/a/8958414
        self.socket.bind(f"tcp://127.0.0.1:{port}")

    def cleanup(self):
        """Close the socket and destroy the ZeroMQ context."""
        self.socket.close()
        self.context.destroy()

    async def _reply(self, identity, payload):
        """Pickle ``payload`` and send it to the peer named by ``identity``."""
        frames = [identity, cloudpickle.dumps(payload)]
        await self.socket.send_multipart(frames)

    async def get_model_config(self, identity):
        """Reply with the engine's ModelConfig."""
        config = await self.engine.get_model_config()
        await self._reply(identity, config)

    async def get_decoding_config(self, identity):
        """Reply with the engine's DecodingConfig."""
        config = await self.engine.get_decoding_config()
        await self._reply(identity, config)

    async def get_lora_config(self, identity):
        """Reply with whatever the engine returns as its LoRA config."""
        config = await self.engine.get_lora_config()
        await self._reply(identity, config)

    async def get_scheduler_config(self, identity):
        """Reply with the engine's SchedulerConfig."""
        scheduler_config = await self.engine.get_scheduler_config()
        await self._reply(identity, scheduler_config)

    async def get_parallel_config(self, identity):
        """Reply with the engine's ParallelConfig."""
        config = await self.engine.get_parallel_config()
        await self._reply(identity, config)

    async def is_tracing_enabled(self, identity):
        """Reply with the engine's is_tracing_enabled flag."""
        flag = await self.engine.is_tracing_enabled()
        await self._reply(identity, flag)

    async def do_log_stats(self, identity):
        """Ask the engine to log stats, then acknowledge with success."""
        await self.engine.do_log_stats()
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def is_server_ready(self, identity):
        """Acknowledge with the success marker to signal readiness."""
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort the request in the engine, then acknowledge with success."""
        await self.engine.abort(request.request_id)
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Run generation and stream each output back to the client.

        Any exception raised while generating or sending is pickled and sent
        to the client instead of propagating.
        """
        try:
            outputs = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for output in outputs:
                await self._reply(identity, output)

        except Exception as err:
            # Notify client of all failures
            await self._reply(identity, err)

    async def check_health(self, identity):
        """Reply with the healthy marker, or with the exception raised."""
        try:
            await self.engine.check_health()
            await self._reply(identity, VLLM_RPC_HEALTHY_STR)
        except Exception as err:
            await self._reply(identity, err)

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""
        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        if isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        if not isinstance(request, RPCUtilityRequest):
            raise ValueError(f"Unknown RPCRequest type: {request}")

        # Utility requests carry no payload; pick the handler by value.
        utility_handlers = {
            RPCUtilityRequest.GET_MODEL_CONFIG: self.get_model_config,
            RPCUtilityRequest.GET_PARALLEL_CONFIG: self.get_parallel_config,
            RPCUtilityRequest.GET_DECODING_CONFIG: self.get_decoding_config,
            RPCUtilityRequest.GET_SCHEDULER_CONFIG: self.get_scheduler_config,
            RPCUtilityRequest.GET_LORA_CONFIG: self.get_lora_config,
            RPCUtilityRequest.DO_LOG_STATS: self.do_log_stats,
            RPCUtilityRequest.IS_SERVER_READY: self.is_server_ready,
            RPCUtilityRequest.CHECK_HEALTH: self.check_health,
            RPCUtilityRequest.IS_TRACING_ENABLED: self.is_tracing_enabled,
        }
        handler = utility_handlers.get(request)
        if handler is None:
            raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
        return handler(identity)

    async def run_server_loop(self):
        """Receive requests forever, handling each in its own task."""
        in_flight = set()
        while True:
            # Block until the next [identity, message] pair arrives.
            frames = await self.socket.recv_multipart()
            identity, message = frames

            # Handle the request concurrently with further receives.
            handler_task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # Hold a strong reference so the task cannot be garbage
            # collected mid-execution; this is the usual pattern for
            # "fire-and-forget" tasks, see
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            in_flight.add(handler_task)
            handler_task.add_done_callback(in_flight.discard)
async def run_server(server: AsyncEngineRPCServer):
    """Run the server loop until SIGINT/SIGTERM, then clean up the server."""
    event_loop = asyncio.get_running_loop()

    # Schedule the receive loop as a task so signal handlers can cancel it.
    serving = event_loop.create_task(server.run_server_loop())

    def _cancel_serving() -> None:
        # Kill the server on interrupt / terminate
        serving.cancel()

    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, _cancel_serving)

    try:
        await serving
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Always release the socket and context, however the loop ended.
        server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int):
    """Construct an AsyncEngineRPCServer and run it under ``asyncio.run``."""
    rpc_server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
    serve_forever = run_server(rpc_server)
    asyncio.run(serve_forever)
vllm/entrypoints/openai/serving_chat.py
View file @
e661d594
import
time
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
List
,
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Dict
,
List
,
Optional
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
Union
...
@@ -8,10 +7,10 @@ from fastapi import Request
...
@@ -8,10 +7,10 @@ from fastapi import Request
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
load_chat_template
,
parse_chat_message
_content
)
parse_chat_message
s
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
...
@@ -25,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...
@@ -25,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath
)
PromptAdapterPath
)
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
...
@@ -41,7 +38,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -41,7 +38,7 @@ class OpenAIServingChat(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
response_role
:
str
,
response_role
:
str
,
...
@@ -50,13 +47,15 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -50,13 +47,15 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
return_tokens_as_token_ids
:
bool
=
False
,
):
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
self
.
response_role
=
response_role
self
.
response_role
=
response_role
...
@@ -89,17 +88,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -89,17 +88,11 @@ class OpenAIServingChat(OpenAIServing):
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
model_config
=
self
.
model_config
model_config
=
self
.
model_config
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
,
mm_futures
=
parse_chat_messages
(
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
request
.
messages
,
model_config
,
tokenizer
)
for
msg
in
request
.
messages
:
chat_parsed_result
=
parse_chat_message_content
(
msg
,
model_config
,
tokenizer
)
conversation
.
extend
(
chat_parsed_result
.
messages
)
mm_futures
.
extend
(
chat_parsed_result
.
mm_futures
)
tool_dicts
=
None
if
request
.
tools
is
None
else
[
tool_dicts
=
None
if
request
.
tools
is
None
else
[
tool
.
model_dump
()
for
tool
in
request
.
tools
tool
.
model_dump
()
for
tool
in
request
.
tools
...
@@ -114,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -114,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
**
(
request
.
chat_template_kwargs
or
{}),
**
(
request
.
chat_template_kwargs
or
{}),
)
)
assert
isinstance
(
prompt
,
str
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -132,28 +126,23 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -132,28 +126,23 @@ class OpenAIServingChat(OpenAIServing):
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_id
=
f
"chat-
{
random_uuid
()
}
"
try
:
try
:
sampling_params
=
request
.
to_sampling_params
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logits_processor
)
prompt_inputs
=
self
.
_tokenize_prompt_input
(
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt
,
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
)
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
prompt_inputs
,
params
=
sampling_params
,
params
=
sampling_params
,
...
@@ -166,7 +155,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -166,7 +155,8 @@ class OpenAIServingChat(OpenAIServing):
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
engine_inputs
[
"multi_modal_data"
]
=
mm_data
engine_inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
is_tracing_enabled
=
(
await
self
.
async_engine_client
.
is_tracing_enabled
())
trace_headers
=
None
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
@@ -174,7 +164,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -174,7 +164,7 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
result_generator
=
self
.
engine
.
generate
(
result_generator
=
self
.
async_engine_client
.
generate
(
engine_inputs
,
engine_inputs
,
sampling_params
,
sampling_params
,
request_id
,
request_id
,
...
@@ -247,7 +237,15 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -247,7 +237,15 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
model
=
model_name
)
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
yield
f
"data:
{
data
}
\n\n
"
...
@@ -277,7 +275,18 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -277,7 +275,18 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
model
=
model_name
)
if
(
request
.
stream_options
and
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
yield
f
"data:
{
data
}
\n\n
"
...
@@ -336,7 +345,19 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -336,7 +345,19 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
model
=
model_name
)
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
yield
f
"data:
{
data
}
\n\n
"
else
:
else
:
...
@@ -356,7 +377,18 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -356,7 +377,18 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
model
=
model_name
)
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
finish_reason_sent
[
i
]
=
True
...
@@ -404,7 +436,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -404,7 +436,7 @@ class OpenAIServingChat(OpenAIServing):
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
raw_request
is
not
None
and
await
raw_request
.
is_disconnected
():
if
raw_request
is
not
None
and
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
request_id
)
await
self
.
async_engine_client
.
abort
(
request_id
)
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res
=
res
final_res
=
res
assert
final_res
is
not
None
assert
final_res
is
not
None
...
@@ -480,11 +512,14 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -480,11 +512,14 @@ class OpenAIServingChat(OpenAIServing):
self
,
logprobs
:
Dict
[
int
,
Logprob
],
top_logprobs
:
Optional
[
int
],
self
,
logprobs
:
Dict
[
int
,
Logprob
],
top_logprobs
:
Optional
[
int
],
tokenizer
:
PreTrainedTokenizer
)
->
List
[
ChatCompletionLogProb
]:
tokenizer
:
PreTrainedTokenizer
)
->
List
[
ChatCompletionLogProb
]:
return
[
return
[
ChatCompletionLogProb
(
ChatCompletionLogProb
(
token
=
(
token
:
=
self
.
_get_decoded_token
(
token
=
(
token
:
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
],
p
[
1
],
tokenizer
)),
p
[
0
],
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
tokenizer
,
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
)))
return_as_token_id
=
self
.
return_tokens_as_token_ids
)),
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
)))
for
i
,
p
in
enumerate
(
logprobs
.
items
())
for
i
,
p
in
enumerate
(
logprobs
.
items
())
if
top_logprobs
and
i
<
top_logprobs
if
top_logprobs
and
i
<
top_logprobs
]
]
...
@@ -504,6 +539,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -504,6 +539,8 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs
=
top_logprobs
[
i
]
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
if
step_top_logprobs
is
None
:
token
=
tokenizer
.
decode
(
token_id
)
token
=
tokenizer
.
decode
(
token_id
)
if
self
.
return_tokens_as_token_ids
:
token
=
f
"token_id:
{
token_id
}
"
logprobs_content
.
append
(
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
ChatCompletionLogProbsContent
(
token
=
token
,
token
=
token
,
...
@@ -511,7 +548,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -511,7 +548,9 @@ class OpenAIServingChat(OpenAIServing):
else
:
else
:
logprobs_content
.
append
(
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
ChatCompletionLogProbsContent
(
token
=
step_top_logprobs
[
token_id
].
decoded_token
,
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
,
tokenizer
,
self
.
return_tokens_as_token_ids
),
logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
),
-
9999.0
),
bytes
=
list
(
bytes
=
list
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
e661d594
...
@@ -8,7 +8,7 @@ from fastapi import Request
...
@@ -8,7 +8,7 @@ from fastapi import Request
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing
,
OpenAIServing
,
PromptAdapterPath
)
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
@@ -44,20 +42,22 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -44,20 +42,22 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
):
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
async
def
create_completion
(
self
,
request
:
CompletionRequest
,
async
def
create_completion
(
self
,
request
:
CompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
...
@@ -91,33 +91,27 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -91,33 +91,27 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
sampling_params
=
request
.
to_sampling_params
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logit_processor
)
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompts
=
list
(
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
request
,
tokenizer
,
tokenizer
,
request
.
prompt
,
request
.
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
))
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
...
@@ -126,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -126,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
is_tracing_enabled
=
(
await
self
.
async_engine_client
.
is_tracing_enabled
())
trace_headers
=
None
trace_headers
=
None
if
is_tracing_enabled
:
if
is_tracing_enabled
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
@@ -134,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -134,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
raw_request
.
headers
):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
generator
=
self
.
engine
.
generate
(
generator
=
self
.
async_engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
sampling_params
,
request_id_item
,
request_id_item
,
...
@@ -175,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -175,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
final_res_batch
[
i
]
=
res
...
@@ -237,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -237,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
for
output
in
res
.
outputs
:
...
@@ -430,12 +426,17 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -430,12 +426,17 @@ class OpenAIServingCompletion(OpenAIServing):
step_top_logprobs
=
top_logprobs
[
i
]
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
if
step_top_logprobs
is
None
:
token
=
tokenizer
.
decode
(
token_id
)
token
=
tokenizer
.
decode
(
token_id
)
if
self
.
return_tokens_as_token_ids
:
token
=
f
"token_id:
{
token_id
}
"
out_tokens
.
append
(
token
)
out_tokens
.
append
(
token
)
out_token_logprobs
.
append
(
None
)
out_token_logprobs
.
append
(
None
)
out_top_logprobs
.
append
(
None
)
out_top_logprobs
.
append
(
None
)
else
:
else
:
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token
=
self
.
_get_decoded_token
(
token_id
,
tokenizer
)
step_top_logprobs
[
token_id
],
token_id
,
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
)
token_logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
token_logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
)
-
9999.0
)
out_tokens
.
append
(
token
)
out_tokens
.
append
(
token
)
...
@@ -448,7 +449,11 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -448,7 +449,11 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs
.
append
({
out_top_logprobs
.
append
({
# Convert float("-inf") to the
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
# JSON-serializable float that OpenAI uses
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
],
tokenizer
):
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
],
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
):
max
(
top_lp
[
1
].
logprob
,
-
9999.0
)
max
(
top_lp
[
1
].
logprob
,
-
9999.0
)
for
i
,
top_lp
in
enumerate
(
step_top_logprobs
.
items
())
for
i
,
top_lp
in
enumerate
(
step_top_logprobs
.
items
())
if
num_output_top_logprobs
>=
i
if
num_output_top_logprobs
>=
i
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
e661d594
...
@@ -6,7 +6,7 @@ import numpy as np
...
@@ -6,7 +6,7 @@ import numpy as np
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponse
,
...
@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
):
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
None
,
lora_modules
=
None
,
...
@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
pooling_params
=
request
.
to_pooling_params
()
pooling_params
=
request
.
to_pooling_params
()
...
@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"Prompt adapter is not supported "
"for embedding models"
)
"for embedding models"
)
generator
=
self
.
engine
.
encode
(
generator
=
self
.
async_engine_client
.
encode
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
pooling_params
,
pooling_params
,
request_id_item
,
request_id_item
,
...
@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
final_res_batch
[
i
]
=
res
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
e661d594
...
@@ -5,11 +5,10 @@ from http import HTTPStatus
...
@@ -5,11 +5,10 @@ from http import HTTPStatus
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
pydantic
import
Field
from
pydantic
import
Field
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -26,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -26,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer_group
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -49,8 +51,6 @@ class LoRAModulePath:
...
@@ -49,8 +51,6 @@ class LoRAModulePath:
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
EmbeddingRequest
,
TokenizeRequest
]
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
TextTokensPrompt
(
TypedDict
):
class
TextTokensPrompt
(
TypedDict
):
prompt
:
str
prompt
:
str
...
@@ -61,17 +61,18 @@ class OpenAIServing:
...
@@ -61,17 +61,18 @@ class OpenAIServing:
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
engine
=
engine
self
.
async_engine_client
=
async_engine_client
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
...
@@ -102,6 +103,7 @@ class OpenAIServing:
...
@@ -102,6 +103,7 @@ class OpenAIServing:
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
self
.
request_logger
=
request_logger
self
.
request_logger
=
request_logger
self
.
return_tokens_as_token_ids
=
return_tokens_as_token_ids
async
def
show_available_models
(
self
)
->
ModelList
:
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
"""Show available models. Right now we only have one model."""
...
@@ -150,6 +152,15 @@ class OpenAIServing:
...
@@ -150,6 +152,15 @@ class OpenAIServing:
})
})
return
json_str
return
json_str
async
def
_guided_decode_logits_processor
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
decoding_config
=
await
self
.
async_engine_client
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
return
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
)
async
def
_check_model
(
async
def
_check_model
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
...
@@ -254,9 +265,7 @@ class OpenAIServing:
...
@@ -254,9 +265,7 @@ class OpenAIServing:
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the messages, "
f
"
{
token_num
}
tokens in the messages, "
f
"Please reduce the length of the messages."
)
f
"Please reduce the length of the messages."
)
request
.
max_tokens
=
self
.
max_model_len
-
token_num
elif
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
if
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
raise
ValueError
(
raise
ValueError
(
f
"This model's maximum context length is "
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
...
@@ -384,11 +393,13 @@ class OpenAIServing:
...
@@ -384,11 +393,13 @@ class OpenAIServing:
)
)
@
staticmethod
@
staticmethod
def
_get_decoded_token
(
def
_get_decoded_token
(
logprob
:
Logprob
,
logprob
:
Logprob
,
token_id
:
int
,
token_id
:
int
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
return_as_token_id
:
bool
=
False
)
->
str
:
)
->
str
:
if
return_as_token_id
:
return
f
"token_id:
{
token_id
}
"
if
logprob
.
decoded_token
is
not
None
:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
logprob
.
decoded_token
return
tokenizer
.
decode
(
token_id
)
return
tokenizer
.
decode
(
token_id
)
vllm/entrypoints/openai/serving_tokenization.py
View file @
e661d594
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
load_chat_template
,
parse_chat_messages
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
DetokenizeResponse
,
DetokenizeResponse
,
ErrorResponse
,
ErrorResponse
,
...
@@ -17,14 +15,17 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
...
@@ -17,14 +15,17 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
class
OpenAIServingTokenization
(
OpenAIServing
):
class
OpenAIServingTokenization
(
OpenAIServing
):
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
*
,
*
,
...
@@ -32,7 +33,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -32,7 +33,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
):
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
...
@@ -57,17 +58,17 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -57,17 +58,17 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
,
mm_futures
=
parse_chat_messages
(
request
.
messages
,
model_config
,
tokenizer
)
for
message
in
request
.
messages
:
if
mm_futures
:
result
=
parse_chat_message_content
(
message
,
model_config
,
logger
.
warning
(
tokenizer
)
"Multi-modal inputs are ignored during tokenization"
)
conversation
.
extend
(
result
.
messages
)
prompt
=
tokenizer
.
apply_chat_template
(
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
...
@@ -113,7 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -113,7 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
request
.
tokens
,
...
...
vllm/envs.py
View file @
e661d594
...
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
...
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
VLLM_HOST_IP
:
str
=
""
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_RPC_PORT
:
int
=
5570
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
...
@@ -28,7 +29,9 @@ if TYPE_CHECKING:
...
@@ -28,7 +29,9 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_OPENVINO_KVCACHE_SPACE
:
int
=
0
VLLM_OPENVINO_KVCACHE_SPACE
:
int
=
0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
...
@@ -36,6 +39,7 @@ if TYPE_CHECKING:
...
@@ -36,6 +39,7 @@ if TYPE_CHECKING:
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
:
bool
=
True
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_ASSETS_CACHE
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"assets"
)
VLLM_ASSETS_CACHE
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"assets"
)
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
@@ -43,10 +47,10 @@ if TYPE_CHECKING:
...
@@ -43,10 +47,10 @@ if TYPE_CHECKING:
MAX_JOBS
:
Optional
[
str
]
=
None
MAX_JOBS
:
Optional
[
str
]
=
None
NVCC_THREADS
:
Optional
[
str
]
=
None
NVCC_THREADS
:
Optional
[
str
]
=
None
VLLM_USE_PRECOMPILED
:
bool
=
False
VLLM_USE_PRECOMPILED
:
bool
=
False
VLLM_INSTALL_PUNICA_KERNELS
:
bool
=
False
VLLM_NO_DEPRECATION_WARNING
:
bool
=
False
VLLM_NO_DEPRECATION_WARNING
:
bool
=
False
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
VERBOSE
:
bool
=
False
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -92,10 +96,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -92,10 +96,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_PRECOMPILED"
:
"VLLM_USE_PRECOMPILED"
:
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_USE_PRECOMPILED"
)),
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_USE_PRECOMPILED"
)),
# If set, vllm will install Punica kernels
"VLLM_INSTALL_PUNICA_KERNELS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_INSTALL_PUNICA_KERNELS"
,
"0"
))),
# CMake build type
# CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo"
# If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
...
@@ -142,6 +142,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -142,6 +142,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'0'
))
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'0'
))
if
'VLLM_PORT'
in
os
.
environ
else
None
,
if
'VLLM_PORT'
in
os
.
environ
else
None
,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT'
:
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'5570'
)),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE"
:
"VLLM_USE_MODELSCOPE"
:
...
@@ -181,6 +186,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -181,6 +186,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
"LOCAL_RANK"
:
"LOCAL_RANK"
:
...
@@ -246,11 +255,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -246,11 +255,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND"
:
"VLLM_ATTENTION_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
# CPU key-value cache space
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
# (CPU backend only) CPU key-value cache space.
# default is 4GB
# default is 4GB
"VLLM_CPU_KVCACHE_SPACE"
:
"VLLM_CPU_KVCACHE_SPACE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_CPU_KVCACHE_SPACE"
,
"0"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_CPU_KVCACHE_SPACE"
,
"0"
)),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND"
:
lambda
:
os
.
getenv
(
"VLLM_CPU_OMP_THREADS_BIND"
,
"all"
),
# OpenVINO key-value cache space
# OpenVINO key-value cache space
# default is 4GB
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE"
:
"VLLM_OPENVINO_KVCACHE_SPACE"
:
...
@@ -272,13 +290,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -272,13 +290,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# execution on all workers.
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER"
:
"VLLM_USE_RAY_SPMD_WORKER"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
0
)),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
"0"
)
)),
# If the env var is set, it uses the Ray's compiled DAG API
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
"VLLM_USE_RAY_COMPILED_DAG"
:
"VLLM_USE_RAY_COMPILED_DAG"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
)),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
"0"
))),
# If the env var is set, it uses NCCL for communication in
# Ray's compiled DAG. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
,
"1"
))
),
# Use dedicated multiprocess context for workers.
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
# Both spawn and fork work
...
@@ -312,6 +337,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -312,6 +337,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will skip the deprecation warnings.
# If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING"
:
"VLLM_NO_DEPRECATION_WARNING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_NO_DEPRECATION_WARNING"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_NO_DEPRECATION_WARNING"
,
"0"
))),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
# To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment