Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1292 additions
and
285 deletions
+1292
-285
vllm/engine/metrics.py
vllm/engine/metrics.py
+31
-16
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+31
-8
vllm/engine/protocol.py
vllm/engine/protocol.py
+84
-0
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+53
-24
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+28
-8
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+46
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+43
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+109
-56
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+11
-0
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+83
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+56
-54
vllm/entrypoints/openai/rpc/__init__.py
vllm/entrypoints/openai/rpc/__init__.py
+42
-0
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+248
-0
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+218
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+84
-45
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+35
-30
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+7
-6
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+26
-15
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+15
-14
vllm/envs.py
vllm/envs.py
+42
-8
No files found.
vllm/engine/metrics.py
View file @
e661d594
...
...
@@ -355,6 +355,7 @@ class StatLoggerBase(ABC):
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
...
...
@@ -364,6 +365,12 @@ class StatLoggerBase(ABC):
def
log
(
self
,
stats
:
Stats
)
->
None
:
raise
NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
    """Cache the spec decode metrics carried by ``stats``, if any.

    Spec decode metrics are unlikely to be emitted on the same iteration
    as the log interval, so the latest non-None value is stored on the
    logger. A ``None`` value leaves the stored metrics untouched.
    """
    latest = stats.spec_decode_metrics
    if latest is None:
        # Nothing new this iteration; keep whatever was saved earlier.
        return
    self.spec_decode_metrics = latest
class
LoggingStatLogger
(
StatLoggerBase
):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
...
...
@@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
...
...
@@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase):
stats
.
cpu_cache_usage_sys
*
100
,
)
if
self
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
self
.
spec_decode_metrics
))
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
stats
.
spec_decode_metrics
))
self
.
spec_decode_metrics
=
None
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
...
...
@@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
...
...
@@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase):
prompt_throughput
=
prompt_throughput
,
generation_throughput
=
generation_throughput
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
if
self
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
s
tats
.
spec_decode_metrics
.
draft_acceptance_rate
)
s
elf
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
s
tats
.
spec_decode_metrics
.
system_efficiency
)
s
elf
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
s
tats
.
spec_decode_metrics
.
accepted_tokens
)
s
elf
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
s
tats
.
spec_decode_metrics
.
draft_tokens
)
s
elf
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
stats
.
spec_decode_metrics
.
emitted_tokens
)
self
.
spec_decode_metrics
.
emitted_tokens
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
...
...
vllm/engine/output_processor/single_step.py
View file @
e661d594
...
...
@@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
n
==
1
and
not
sampling_params
.
use_beam_search
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
# only have one sequence
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
else
:
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
,
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
)
if
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
return
# Process samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
...
@@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
child_seqs
.
append
((
parent
,
parent
))
for
seq
,
_
in
child_seqs
:
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
seq_group
.
sampling_params
)
seq
,
sampling_params
)
else
:
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
,
seq_group
.
sampling_params
,
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
)
# Non-beam search case
if
not
seq_group
.
sampling_params
.
use_beam_search
:
if
not
sampling_params
.
use_beam_search
:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for
seq
,
parent
in
child_seqs
:
...
...
@@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Select the child sequences to keep in the sequence group.
selected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
unselected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
beam_width
=
seq_group
.
sampling_params
.
best_of
length_penalty
=
seq_group
.
sampling_params
.
length_penalty
beam_width
=
sampling_params
.
best_of
length_penalty
=
sampling_params
.
length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
...
...
@@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
best_running_seq
=
running_child_seqs
[
0
][
0
]
current_worst_seq
=
all_finished_seqs
[
beam_width
-
1
][
0
]
stop_beam_search
=
self
.
_check_beam_search_early_stopping
(
seq_group
.
sampling_params
.
early_stopping
,
seq_group
.
sampling_params
,
best_running_seq
,
current_worst_seq
)
sampling_params
.
early_stopping
,
sampling_params
,
best_running_seq
,
current_worst_seq
)
if
stop_beam_search
:
# Stop the beam search and remove all the running sequences from
...
...
vllm/engine/protocol.py
0 → 100644
View file @
e661d594
from
typing
import
(
AsyncIterator
,
List
,
Mapping
,
Optional
,
Protocol
,
runtime_checkable
)
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
@runtime_checkable
class AsyncEngineClient(Protocol):
    """Protocol class for Clients to AsyncLLMEngine.

    This is a structural interface: any object exposing the members below
    satisfies it, without inheriting from this class. Because the class is
    decorated with ``runtime_checkable``, ``isinstance`` checks against it
    are allowed; such checks only verify that the members exist, not their
    signatures. Every method body here is a stub.
    """

    @property
    def is_running(self) -> bool:
        # Stub; presumably reports whether the engine loop is active —
        # confirm against the concrete implementation.
        ...

    @property
    def is_stopped(self) -> bool:
        # Stub; presumably reports whether the engine has been stopped —
        # confirm against the concrete implementation.
        ...

    @property
    def errored(self) -> bool:
        # Stub; presumably reports whether the engine hit an error —
        # confirm against the concrete implementation.
        ...

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generates outputs for a request.

        Declared to produce an async iterator of ``RequestOutput`` for the
        given inputs and sampling parameters, identified by ``request_id``.
        The LoRA, trace-header and prompt-adapter arguments are optional.
        """

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        Declared to produce an async iterator of ``EmbeddingRequestOutput``.
        """

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> PreTrainedTokenizer:
        """Get the appropriate Tokenizer for the request.

        Args:
            lora_request: Optional LoRA request; presumably selects a
                LoRA-specific tokenizer when given — confirm against the
                concrete implementation.
        """

    async def is_tracing_enabled(self) -> bool:
        """Stub declared to return a bool.

        NOTE(review): presumably reports whether request tracing is
        active; the protocol itself does not define the semantics.
        """
        pass

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Stub for a stats-logging hook; both arguments are optional.

        NOTE(review): how implementations use ``scheduler_outputs`` and
        ``model_output`` is not visible from this protocol.
        """
        pass

    async def check_health(self) -> None:
        """Raise if unhealthy"""
vllm/entrypoints/api_server.py
View file @
e661d594
...
...
@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import
asyncio
import
json
import
ssl
from
typing
import
AsyncGenerator
from
argparse
import
Namespace
from
typing
import
Any
,
AsyncGenerator
,
Optional
import
uvicorn
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
...
@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
return
JSONResponse
(
ret
)
def build_app(args: Namespace) -> FastAPI:
    """Apply CLI settings to the module-level FastAPI ``app`` and return it.

    Only ``args.root_path`` is consumed here. No new application object is
    created; the existing module-level instance is mutated in place.
    """
    # ``app`` is only mutated, never rebound, so a plain global read suffices.
    app.root_path = args.root_path
    return app
async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Prepare the FastAPI app and bind the module-level ``engine``.

    Args:
        args: Parsed CLI arguments.
        llm_engine: Engine to use as-is. When ``None``, a new
            ``AsyncLLMEngine`` is built from ``args``.

    Returns:
        The configured FastAPI application.
    """
    global engine

    configured_app = build_app(args)

    # The engine args are parsed unconditionally, even when an engine is
    # injected by the caller.
    parsed_engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        engine = AsyncLLMEngine.from_engine_args(
            parsed_engine_args, usage_context=UsageContext.API_SERVER)
    else:
        engine = llm_engine

    return configured_app
async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Start the demo API server and block until it has shut down.

    Args:
        args: Parsed CLI arguments (host, port, log level, SSL options).
        llm_engine: Optional pre-built engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments passed through to
            ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    application = await init_app(args, llm_engine)

    # serve_http returns an awaitable that completes the shutdown; it is
    # awaited separately once serving has ended.
    pending_shutdown = await serve_http(
        application,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    await pending_shutdown
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
...
...
@@ -105,25 +154,5 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"debug"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
)
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"Route: %s, Methods: %s"
,
route
.
path
,
methods
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
asyncio
.
run
(
run_server
(
args
))
vllm/entrypoints/chat_utils.py
View file @
e661d594
import
codecs
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
Awaitable
,
Iterable
,
List
,
Optional
,
Union
,
cast
,
final
from
typing
import
(
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
default_factory
=
list
)
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
...
...
@@ -100,14 +100,16 @@ def _image_token_str(model_config: ModelConfig,
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
==
"chameleon"
:
if
model_type
in
(
"chameleon"
,
"internvl_chat"
)
:
return
"<image>"
raise
TypeError
(
"Unknown model type: {model_type}"
)
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
# TODO: Let user specify how to insert image tokens into prompt
...
...
@@ -172,7 +174,7 @@ def _parse_chat_message_content_parts(
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
parse_chat_message_content
(
def
_
parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
,
...
...
@@ -188,3 +190,21 @@ def parse_chat_message_content(
return
_parse_chat_message_content_parts
(
role
,
content
,
model_config
,
tokenizer
)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse every chat message and flatten the per-message results.

    Each message is parsed with ``_parse_chat_message_content``. The parsed
    conversation messages and the multi-modal futures are concatenated in
    the order of ``messages``.

    Returns:
        A ``(conversation, mm_futures)`` tuple of two flat lists.
    """
    parsed = [
        _parse_chat_message_content(message, model_config, tokenizer)
        for message in messages
    ]

    conversation: List[ConversationMessage] = [
        item for result in parsed for item in result.messages
    ]
    mm_futures: List[Awaitable[MultiModalDataDict]] = [
        future for result in parsed for future in result.mm_futures
    ]

    return conversation, mm_futures
vllm/entrypoints/launcher.py
0 → 100644
View file @
e661d594
import
asyncio
import
signal
from
typing
import
Any
import
uvicorn
from
fastapi
import
FastAPI
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Serve ``app`` with uvicorn on the running event loop.

    Logs the available routes, runs the uvicorn server as a task, and
    installs SIGINT/SIGTERM handlers that cancel that task.

    Args:
        app: The FastAPI application to serve.
        **uvicorn_kwargs: Forwarded verbatim to ``uvicorn.Config``.

    Returns:
        An un-awaited coroutine that the caller must await to finish the
        shutdown. It is a no-op when the server stopped by itself, and
        ``server.shutdown()`` when serving was cancelled.
    """
    logger.info("Available routes are:")
    for entry in app.routes:
        route_methods = getattr(entry, "methods", None)
        route_path = getattr(entry, "path", None)

        # Entries lacking either attribute are skipped.
        if route_methods is not None and route_path is not None:
            logger.info("Route: %s, Methods: %s", route_path,
                        ', '.join(route_methods))

    http_server = uvicorn.Server(uvicorn.Config(app, **uvicorn_kwargs))

    event_loop = asyncio.get_running_loop()
    serving = event_loop.create_task(http_server.serve())

    def signal_handler() -> None:
        # prevents the uvicorn signal handler to exit early
        serving.cancel()

    async def dummy_shutdown() -> None:
        pass

    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, signal_handler)

    try:
        await serving
    except asyncio.CancelledError:
        logger.info("Gracefully stopping http server")
        return http_server.shutdown()

    # Server finished without being cancelled: nothing left to shut down.
    return dummy_shutdown()
vllm/entrypoints/llm.py
View file @
e661d594
...
...
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
GuidedDecodingRequest
,
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.guided_fields
import
LLMGuidedOptions
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -262,6 +265,8 @@ class LLM:
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
...
...
@@ -303,6 +308,14 @@ class LLM:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
raise
ValueError
(
"You can only use one guided decoding but multiple is "
f
"specified:
{
guided_options_request
}
"
)
guided_options_request
=
GuidedDecodingRequest
(
**
guided_options_request
)
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
...
...
@@ -311,7 +324,8 @@ class LLM:
inputs
=
inputs
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
...
...
@@ -508,6 +522,7 @@ class LLM:
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
...
...
@@ -523,6 +538,15 @@ class LLM:
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
if
isinstance
(
params
,
list
):
params
=
[
self
.
_add_guided_processor
(
param
,
guided_options
)
if
isinstance
(
param
,
SamplingParams
)
else
param
for
param
in
params
]
elif
isinstance
(
params
,
SamplingParams
):
params
=
self
.
_add_guided_processor
(
params
,
guided_options
)
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
self
.
_add_request
(
...
...
@@ -548,6 +572,24 @@ class LLM:
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
def
_add_guided_processor
(
self
,
params
:
SamplingParams
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
):
if
guided_options
:
if
guided_options
.
guided_decoding_backend
is
None
:
decoding_config
=
self
.
llm_engine
.
get_decoding_config
()
guided_options
.
guided_decoding_backend
=
(
decoding_config
.
guided_decoding_backend
)
guided_logits_processor
=
get_local_guided_decoding_logits_processor
(
#noqa
guided_options
.
guided_decoding_backend
,
guided_options
,
self
.
get_tokenizer
())
if
guided_logits_processor
:
if
params
.
logits_processors
is
None
:
params
.
logits_processors
=
[]
params
.
logits_processors
.
append
(
guided_logits_processor
)
return
params
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
...
...
vllm/entrypoints/openai/api_server.py
View file @
e661d594
...
...
@@ -2,13 +2,13 @@ import asyncio
import
importlib
import
inspect
import
re
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
typing
import
Optional
,
Set
from
multiprocessing
import
Process
from
typing
import
AsyncIterator
,
Set
import
fastapi
import
uvicorn
from
fastapi
import
APIRouter
,
Request
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
...
@@ -16,8 +16,11 @@ from prometheus_client import make_asgi_app
from
starlette.routing
import
Mount
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
# yapf conflicts with isort for this block
...
...
@@ -30,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest
,
ErrorResponse
,
TokenizeRequest
,
TokenizeResponse
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
...
...
@@ -38,12 +43,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
engine
:
Async
LLM
Engine
async_engine_client
:
AsyncEngine
Client
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
...
...
@@ -55,13 +60,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig for ``model_name``.

    A throwaway ``ModelConfig`` is built with fixed settings (the model
    name doubles as the tokenizer, "auto" tokenizer mode, seed 0,
    "float16" dtype) solely to read its ``embedding_mode`` attribute.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode
@
asynccontextmanager
async
def
lifespan
(
app
:
fastapi
.
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
async
def
_force_log
():
while
True
:
await
asyncio
.
sleep
(
10
)
await
engine
.
do_log_stats
()
await
async_engine_client
.
do_log_stats
()
if
not
engine_args
.
disable_log_stats
:
task
=
asyncio
.
create_task
(
_force_log
())
...
...
@@ -71,10 +85,56 @@ async def lifespan(app: fastapi.FastAPI):
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Yield an engine client and clean it up on exit.

    The engine args are parsed from ``args`` and stored in the module-level
    ``engine_args``; the client is stored in the module-level
    ``async_engine_client`` (still needed by the health handler).

    Two modes:
      * In-process ``AsyncLLMEngine`` when the model is an embedding model
        or ``args.disable_frontend_multiprocessing`` is set.
      * Otherwise an RPC server process is started and an
        ``AsyncEngineRPCClient`` connected to it is yielded.

    In RPC mode the server process is terminated and joined on every exit
    path, including failures while creating or setting up the client.
    """
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    port = get_open_port(envs.VLLM_RPC_PORT)
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER, port))
    rpc_server_process.start()

    # Everything after start() runs under try/finally so that a failure or
    # cancellation during client construction/setup cannot leave the child
    # process running.
    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        rpc_client = AsyncEngineRPCClient(port)
        async_engine_client = rpc_client
        await rpc_client.setup()

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend (only if the client
        # was actually constructed).
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()
router
=
APIRouter
()
def
mount_metrics
(
app
:
fastapi
.
FastAPI
):
def
mount_metrics
(
app
:
FastAPI
):
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
...
...
@@ -85,7 +145,7 @@ def mount_metrics(app: fastapi.FastAPI):
@
router
.
get
(
"/health"
)
async
def
health
()
->
Response
:
"""Health check."""
await
openai_serving_chat
.
engine
.
check_health
()
await
async_engine_client
.
check_health
()
return
Response
(
status_code
=
200
)
...
...
@@ -164,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
def
build_app
(
args
)
:
app
=
fastapi
.
FastAPI
(
lifespan
=
lifespan
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
...
...
@@ -213,37 +273,18 @@ def build_app(args):
return
app
def
run_server
(
args
,
llm_engine
=
None
):
async
def
init_app
(
async_engine_client
:
AsyncEngineClient
,
args
:
Namespace
,
)
->
FastAPI
:
app
=
build_app
(
args
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
else
:
served_model_names
=
[
args
.
model
]
global
engine
,
engine_args
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
(
llm_engine
if
llm_engine
is
not
None
else
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
))
event_loop
:
Optional
[
asyncio
.
AbstractEventLoop
]
try
:
event_loop
=
asyncio
.
get_running_loop
()
except
RuntimeError
:
event_loop
=
None
if
event_loop
is
not
None
and
event_loop
.
is_running
():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config
=
event_loop
.
run_until_complete
(
engine
.
get_model_config
())
else
:
# When using single vLLM without engine_use_ray
model_config
=
asyncio
.
run
(
engine
.
get_model_config
())
model_config
=
await
async_engine_client
.
get_model_config
()
if
args
.
disable_log_requests
:
request_logger
=
None
...
...
@@ -256,7 +297,7 @@ def run_server(args, llm_engine=None):
global
openai_serving_tokenization
openai_serving_chat
=
OpenAIServingChat
(
engine
,
async_engine_client
,
model_config
,
served_model_names
,
args
.
response_role
,
...
...
@@ -264,23 +305,25 @@ def run_server(args, llm_engine=None):
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
async_engine_client
,
model_config
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
async_engine_client
,
model_config
,
served_model_names
,
request_logger
=
request_logger
,
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine
,
async_engine_client
,
model_config
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
...
...
@@ -289,22 +332,31 @@ def run_server(args, llm_engine=None):
)
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"
Route: %s, Methods: %s"
,
route
.
path
,
method
s
)
return
app
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"
args: %s"
,
arg
s
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
app
=
await
init_app
(
async_engine_client
,
args
)
shutdown_task
=
await
serve_http
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
**
uvicorn_kwargs
,
)
# NB: Await server shutdown only after the backend context is exited
await
shutdown_task
if
__name__
==
"__main__"
:
...
...
@@ -314,4 +366,5 @@ if __name__ == "__main__":
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
run_server
(
args
)
asyncio
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
e661d594
...
...
@@ -128,6 +128,17 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). "
)
parser
.
add_argument
(
"--return-tokens-as-token-ids"
,
action
=
"store_true"
,
help
=
"When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified."
)
parser
.
add_argument
(
"--disable-frontend-multiprocessing"
,
action
=
"store_true"
,
help
=
"If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
...
...
vllm/entrypoints/openai/logits_processors.py
0 → 100644
View file @
e661d594
from
functools
import
lru_cache
,
partial
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Union
import
torch
from
transformers
import
PreTrainedTokenizer
from
vllm.sampling_params
import
LogitsProcessor
class AllowedTokenIdsLogitsProcessor:
    """Restrict sampling to a fixed collection of token ids.

    Every logit whose index is not in the allowed collection is overwritten
    with ``-inf`` in place.
    """

    def __init__(self, allowed_ids: Iterable[int]):
        # Kept only until the boolean mask is built on the first call.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        # True marks a *disallowed* position; built lazily because the
        # vocabulary dimension and device are only known from the logits.
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        """Mask ``logits`` in place and return the same tensor.

        ``token_ids`` is accepted for interface compatibility and is not
        read by this processor.
        """
        if self.mask is None:
            vocab_dim = logits.shape[-1]
            self.mask = torch.ones((vocab_dim, ),
                                   dtype=torch.bool,
                                   device=logits.device)
            # Clear the mask at every allowed position.
            self.mask[self.allowed_ids] = False
            # The id list is no longer needed once the mask exists.
            self.allowed_ids = None

        neg_inf = float("-inf")
        logits.masked_fill_(self.mask, neg_inf)
        return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Validate ``allowed_token_ids`` and build the matching processor.

    Results are memoized on ``(allowed_token_ids, vocab_size)``, which is
    why the ids must arrive as a hashable ``frozenset``.

    Raises:
        ValueError: if the set is empty or contains an id outside
            ``[0, vocab_size)``.
    """
    if len(allowed_token_ids) == 0:
        raise ValueError("Empty allowed_token_ids provided")

    has_out_of_vocab = any(tid < 0 or tid >= vocab_size
                           for tid in allowed_token_ids)
    if has_out_of_vocab:
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")

    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[int, float],
                                token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    """Add a per-token bias to ``logits`` in place.

    Args:
        logit_bias: Mapping from integer token id to the bias to add. The
            keys are used directly as tensor indices, so they must already
            be integers (string keys have to be converted by the caller).
        token_ids: Accepted for interface compatibility; not read here.
        logits: Logits tensor, modified in place.

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits
def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
    """Build the logits processors requested by ``logit_bias`` and
    ``allowed_token_ids``.

    The bias processor (if any) is appended before the allowed-token
    processor (if any). An empty list is returned when neither applies.

    Raises:
        ValueError: if a ``logit_bias`` key cannot be converted to an
            integer, or if any token id lies outside
            ``[0, tokenizer.vocab_size)``.
    """
    processors = []

    if logit_bias:
        clamped_logit_bias: Dict[int, float] = {}
        try:
            for raw_token_id, bias in logit_bias.items():
                # Convert token_id to integer
                token_id = int(raw_token_id)
                # Clamp the bias between -100 and 100 per OpenAI API spec
                clamped_logit_bias[token_id] = min(100.0, max(-100.0, bias))
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check if token_id is within the vocab size
        vocab_size = tokenizer.vocab_size
        if any(not 0 <= token_id < vocab_size
               for token_id in clamped_logit_bias):
            raise ValueError("token_id in logit_bias contains "
                             "out-of-vocab token id")

        bias_processor = partial(logit_bias_logits_processor,
                                 clamped_logit_bias)
        processors.append(bias_processor)

    if allowed_token_ids is not None:
        # frozenset makes the ids hashable for the cached factory.
        allowed_processor = _get_allowed_token_ids_logits_processor(
            frozenset(allowed_token_ids), tokenizer.vocab_size)
        processors.append(allowed_processor)

    return processors
vllm/entrypoints/openai/protocol.py
View file @
e661d594
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
from
argparse
import
Namespace
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO
=
Namespace
(
min
=-
9223372036854775808
,
max
=
9223372036854775807
)
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
if
isinstance
(
torch
,
_MockModule
):
_LONG_INFO
=
_MOCK_LONG_INFO
else
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
except
ModuleNotFoundError
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
assert
_LONG_INFO
.
min
==
_MOCK_LONG_INFO
.
min
assert
_LONG_INFO
.
max
==
_MOCK_LONG_INFO
.
max
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
...
...
@@ -106,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
...
...
@@ -213,30 +231,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
)
->
SamplingParams
:
# We now allow logprobs being true without top_logrobs.
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
logits_processors
=
None
if
self
.
logit_bias
:
logit_bias
:
Dict
[
int
,
float
]
=
{}
try
:
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
# Convert token_id to integer before we add to LLMEngine
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
# We now allow logprobs being true without top_logrobs.
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
None
,
tokenizer
=
tokenizer
,
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
n
=
self
.
n
,
...
...
@@ -254,7 +264,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
...
...
@@ -333,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens
:
Optional
[
int
]
=
16
n
:
int
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
...
...
@@ -358,6 +366,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
...
...
@@ -407,30 +416,23 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def
to_sampling_params
(
self
):
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
None
if
self
.
logit_bias
:
logit_bias
:
Dict
[
int
,
float
]
=
{}
try
:
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
self
.
allowed_token_ids
,
tokenizer
=
tokenizer
,
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
n
=
self
.
n
,
...
...
@@ -447,7 +449,7 @@ class CompletionRequest(OpenAIBaseModel):
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
logprobs
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
if
not
echo_without_generation
else
1
,
max_tokens
=
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
...
...
vllm/entrypoints/openai/rpc/__init__.py
0 → 100644
View file @
e661d594
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Mapping
,
Optional
,
Union
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
# Acknowledgement string for a successfully handled RPC request.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Reply string for a health check that found the backend healthy.
VLLM_RPC_HEALTHY_STR = "HEALTHY"
@dataclass
class RPCGenerateRequest:
    """Payload of a generation request sent over the RPC channel."""

    # Required fields.
    inputs: PromptInputs
    sampling_params: SamplingParams
    request_id: str

    # Optional fields; each defaults to None.
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
    """Payload asking for the request with ``request_id`` to be aborted."""

    request_id: str
class RPCUtilityRequest(Enum):
    """Argument-free RPC requests, identified by enum member."""

    # Readiness probe.
    IS_SERVER_READY = 1

    # Configuration queries.
    GET_MODEL_CONFIG = 2
    GET_DECODING_CONFIG = 3
    GET_PARALLEL_CONFIG = 4
    GET_SCHEDULER_CONFIG = 5
    GET_LORA_CONFIG = 6

    # Actions and status checks.
    DO_LOG_STATS = 7
    CHECK_HEALTH = 8
    IS_TRACING_ENABLED = 9
# Union of every request object that can be sent over the RPC channel.
RPC_REQUEST_TYPE = Union[
    RPCGenerateRequest,
    RPCAbortRequest,
    RPCUtilityRequest,
]
vllm/entrypoints/openai/rpc/client.py
0 → 100644
View file @
e661d594
from
contextlib
import
contextmanager
from
typing
import
Any
,
AsyncIterator
,
Mapping
,
Optional
import
cloudpickle
import
zmq
import
zmq.asyncio
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
class AsyncEngineRPCClient:
    """Client side of the ZeroMQ RPC channel to the engine process.

    Each request opens a fresh DEALER socket via :meth:`socket`, sends one
    cloudpickle-serialized request and reads the cloudpickle-serialized
    reply (or stream of replies for :meth:`generate`).

    :meth:`setup` must be awaited before the config getters,
    :meth:`get_tokenizer` or :meth:`is_tracing_enabled` are used, because
    the attributes they read are only assigned there.
    """

    def __init__(self, port: int):
        self.context = zmq.asyncio.Context()
        self.path = f"tcp://localhost:{port}"

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        """Yield a connected DEALER socket that is closed on exit."""
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.path)
            yield socket
        finally:
            socket.close()

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Raises:
            ValueError: with ``error_message`` if the reply is not an
                instance of ``expected_type`` (a ``None`` reply is accepted
                when ``expected_type`` is ``LoRAConfig``).
        """
        with self.socket() as socket:
            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        # LoRAConfig can be None.
        if not isinstance(data, expected_type) and not (
                expected_type == LoRAConfig and data is None):
            raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action.

        Raises:
            ValueError: with ``error_message`` if the reply is anything
                other than the success string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: Optional[LoRARequest]):
        """Return the tokenizer from the local tokenizer group."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig fetched during :meth:`setup`."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig fetched during :meth:`setup`."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag fetched during :meth:`setup`."""
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
        """Get LoRAConfig from the RPCServer (may be ``None``)"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields every output received, ending with the one whose
        ``finished`` attribute is true. If the server replies with an
        exception object, that exception is raised here.
        """
        with self.socket() as socket:
            # Send RPCGenerateRequest to the RPCServer.
            await socket.send_multipart([
                cloudpickle.dumps(
                    RPCGenerateRequest(
                        inputs=inputs,
                        sampling_params=sampling_params,
                        request_id=request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request))
            ])

            # Stream back the results from the RPC Server.
            while True:
                message = await socket.recv()
                request_output = cloudpickle.loads(message)

                if isinstance(request_output, Exception):
                    raise request_output

                if request_output.finished:
                    break
                yield request_output

        # The final (finished) output is yielded after the loop.
        yield request_output

    async def check_health(self) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:
            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(
                cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

        if isinstance(health_message, Exception):
            raise health_message

        if health_message != VLLM_RPC_HEALTHY_STR:
            # The unexpected reply is interpolated so it shows up in the error.
            raise ValueError("Expected healthy response from backend but got "
                             f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
        """Embeddings are not available through this client."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")
vllm/entrypoints/openai/rpc/server.py
0 → 100644
View file @
e661d594
import
asyncio
import
signal
from
typing
import
Any
,
Coroutine
import
cloudpickle
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
# Module-level logger, named after this module.
logger = init_logger(__name__)
class AsyncEngineRPCServer:
    """Serve an AsyncLLMEngine over a ZeroMQ ROUTER socket.

    Each incoming message is a two-frame multipart: the sender identity
    and a cloudpickle-serialized request. Replies are sent back as
    ``[identity, cloudpickle.dumps(payload)]``.
    """

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, port: int):
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                      usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        # Note numeric form of localhost should be used for zmq bind(),
        # see https://stackoverflow.com/a/8958414
        self.socket.bind(f"tcp://127.0.0.1:{port}")

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()

    async def _reply(self, identity, payload):
        """Pickle ``payload`` and send it to the client ``identity``."""
        frames = [identity, cloudpickle.dumps(payload)]
        await self.socket.send_multipart(frames)

    async def get_model_config(self, identity):
        """Send the ModelConfig"""
        await self._reply(identity, await self.engine.get_model_config())

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig"""
        await self._reply(identity, await self.engine.get_decoding_config())

    async def get_lora_config(self, identity):
        """Send the LoRAConfig"""
        await self._reply(identity, await self.engine.get_lora_config())

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig"""
        scheduler_config = await self.engine.get_scheduler_config()
        await self._reply(identity, scheduler_config)

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig"""
        await self._reply(identity, await self.engine.get_parallel_config())

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        await self._reply(identity, await self.engine.is_tracing_enabled())

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        # Abort the request in the llm engine.
        await self.engine.abort(request.request_id)

        # Send confirmation to the client.
        await self._reply(identity, VLLM_RPC_SUCCESS_STR)

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Run generation and send each output to the client.

        Any exception raised while generating or sending is itself sent to
        the client instead of propagating.
        """
        try:
            output_stream = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in output_stream:
                await self._reply(identity, request_output)

        except Exception as e:
            ### Notify client of all failures
            await self._reply(identity, e)

    async def check_health(self, identity):
        """Send the healthy string, or the exception that was raised."""
        try:
            await self.engine.check_health()
            await self._reply(identity, VLLM_RPC_HEALTHY_STR)
        except Exception as e:
            await self._reply(identity, e)

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        if isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        if not isinstance(request, RPCUtilityRequest):
            raise ValueError(f"Unknown RPCRequest type: {request}")

        # Utility requests carry no arguments; dispatch on the enum member.
        utility_handlers = {
            RPCUtilityRequest.GET_MODEL_CONFIG: self.get_model_config,
            RPCUtilityRequest.GET_PARALLEL_CONFIG: self.get_parallel_config,
            RPCUtilityRequest.GET_DECODING_CONFIG: self.get_decoding_config,
            RPCUtilityRequest.GET_SCHEDULER_CONFIG: self.get_scheduler_config,
            RPCUtilityRequest.GET_LORA_CONFIG: self.get_lora_config,
            RPCUtilityRequest.DO_LOG_STATS: self.do_log_stats,
            RPCUtilityRequest.IS_SERVER_READY: self.is_server_ready,
            RPCUtilityRequest.CHECK_HEALTH: self.check_health,
            RPCUtilityRequest.IS_TRACING_ENABLED: self.is_tracing_enabled,
        }
        handler = utility_handlers.get(request)
        if handler is None:
            raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
        return handler(identity)

    async def run_server_loop(self):
        """Inner RPC Server Loop"""

        # We need to keep around a strong reference to the task,
        # to avoid the task disappearing mid-execution as running tasks
        # can be GC'ed. Below is a common "fire-and-forget" tasks
        # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request async.
            handler_coro = self._make_handler_coro(identity, message)
            task = asyncio.create_task(handler_coro)

            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
    """Run the server loop until SIGINT/SIGTERM, then clean up.

    The server's resources are released in every case, whether the loop
    was cancelled by a signal or ended with an error.
    """
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate
        server_task.cancel()

    for signum in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(signum, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int):
    """Build the RPC server on ``port`` and run it until it stops."""
    rpc_server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
    serve_coro = run_server(rpc_server)
    asyncio.run(serve_coro)
vllm/entrypoints/openai/serving_chat.py
View file @
e661d594
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
List
,
Optional
)
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
...
...
@@ -8,10 +7,10 @@ from fastapi import Request
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message
_content
)
parse_chat_message
s
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
...
...
@@ -25,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath
)
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
...
...
@@ -41,7 +38,7 @@ class OpenAIServingChat(OpenAIServing):
def
__init__
(
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
response_role
:
str
,
...
...
@@ -50,13 +47,15 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
return_tokens_as_token_ids
:
bool
=
False
,
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
self
.
response_role
=
response_role
...
...
@@ -89,17 +88,11 @@ class OpenAIServingChat(OpenAIServing):
)
=
self
.
_maybe_get_adapters
(
request
)
model_config
=
self
.
model_config
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
conversation
:
List
[
ConversationMessage
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
msg
in
request
.
messages
:
chat_parsed_result
=
parse_chat_message_content
(
msg
,
model_config
,
tokenizer
)
conversation
.
extend
(
chat_parsed_result
.
messages
)
mm_futures
.
extend
(
chat_parsed_result
.
mm_futures
)
conversation
,
mm_futures
=
parse_chat_messages
(
request
.
messages
,
model_config
,
tokenizer
)
tool_dicts
=
None
if
request
.
tools
is
None
else
[
tool
.
model_dump
()
for
tool
in
request
.
tools
...
...
@@ -114,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
**
(
request
.
chat_template_kwargs
or
{}),
)
assert
isinstance
(
prompt
,
str
)
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -132,28 +126,23 @@ class OpenAIServingChat(OpenAIServing):
request_id
=
f
"chat-
{
random_uuid
()
}
"
try
:
sampling_params
=
request
.
to_sampling_params
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logits_processor
)
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
params
=
sampling_params
,
...
...
@@ -166,7 +155,8 @@ class OpenAIServingChat(OpenAIServing):
if
mm_data
is
not
None
:
engine_inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
is_tracing_enabled
=
(
await
self
.
async_engine_client
.
is_tracing_enabled
())
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
...
@@ -174,7 +164,7 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
result_generator
=
self
.
engine
.
generate
(
result_generator
=
self
.
async_engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
...
...
@@ -247,7 +237,15 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
...
...
@@ -277,7 +275,18 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
...
...
@@ -336,7 +345,19 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
else
:
...
...
@@ -356,7 +377,18 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
chunk
.
usage
=
None
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
...
...
@@ -404,7 +436,7 @@ class OpenAIServingChat(OpenAIServing):
async
for
res
in
result_generator
:
if
raw_request
is
not
None
and
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
request_id
)
await
self
.
async_engine_client
.
abort
(
request_id
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res
=
res
assert
final_res
is
not
None
...
...
@@ -480,11 +512,14 @@ class OpenAIServingChat(OpenAIServing):
self
,
logprobs
:
Dict
[
int
,
Logprob
],
top_logprobs
:
Optional
[
int
],
tokenizer
:
PreTrainedTokenizer
)
->
List
[
ChatCompletionLogProb
]:
return
[
ChatCompletionLogProb
(
token
=
(
token
:
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
],
tokenizer
)),
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
)))
ChatCompletionLogProb
(
token
=
(
token
:
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
],
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
)),
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
)))
for
i
,
p
in
enumerate
(
logprobs
.
items
())
if
top_logprobs
and
i
<
top_logprobs
]
...
...
@@ -504,6 +539,8 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
tokenizer
.
decode
(
token_id
)
if
self
.
return_tokens_as_token_ids
:
token
=
f
"token_id:
{
token_id
}
"
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
token
=
token
,
...
...
@@ -511,7 +548,9 @@ class OpenAIServingChat(OpenAIServing):
else
:
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
token
=
step_top_logprobs
[
token_id
].
decoded_token
,
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
,
tokenizer
,
self
.
return_tokens_as_token_ids
),
logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
),
bytes
=
list
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
e661d594
...
...
@@ -8,7 +8,7 @@ from fastapi import Request
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing
,
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
...
@@ -44,20 +42,22 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
async
def
create_completion
(
self
,
request
:
CompletionRequest
,
raw_request
:
Request
):
...
...
@@ -91,33 +91,27 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
sampling_params
=
request
.
to_sampling_params
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logit_processor
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
tokenizer
,
request
.
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
...
...
@@ -126,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
is_tracing_enabled
=
(
await
self
.
async_engine_client
.
is_tracing_enabled
())
trace_headers
=
None
if
is_tracing_enabled
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
...
@@ -134,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
log_tracing_disabled_warning
()
generator
=
self
.
engine
.
generate
(
generator
=
self
.
async_engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
request_id_item
,
...
...
@@ -175,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
...
...
@@ -237,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
...
...
@@ -430,12 +426,17 @@ class OpenAIServingCompletion(OpenAIServing):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
tokenizer
.
decode
(
token_id
)
if
self
.
return_tokens_as_token_ids
:
token
=
f
"token_id:
{
token_id
}
"
out_tokens
.
append
(
token
)
out_token_logprobs
.
append
(
None
)
out_top_logprobs
.
append
(
None
)
else
:
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
,
tokenizer
)
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
,
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
)
token_logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
)
out_tokens
.
append
(
token
)
...
...
@@ -448,7 +449,11 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs
.
append
({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
],
tokenizer
):
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
],
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
):
max
(
top_lp
[
1
].
logprob
,
-
9999.0
)
for
i
,
top_lp
in
enumerate
(
step_top_logprobs
.
items
())
if
num_output_top_logprobs
>=
i
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
e661d594
...
...
@@ -6,7 +6,7 @@ import numpy as np
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
...
...
@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def
__init__
(
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
request_logger
:
Optional
[
RequestLogger
],
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
,
...
...
@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
pooling_params
=
request
.
to_pooling_params
()
...
...
@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models"
)
generator
=
self
.
engine
.
encode
(
generator
=
self
.
async_engine_client
.
encode
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
pooling_params
,
request_id_item
,
...
...
@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
await
self
.
async_engine_client
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
e661d594
...
...
@@ -5,11 +5,10 @@ from http import HTTPStatus
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
pydantic
import
Field
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.engine.
async_llm_engine
import
Async
LLM
Engine
from
vllm.engine.
protocol
import
AsyncEngine
Client
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -26,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer_group
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -49,8 +51,6 @@ class LoRAModulePath:
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
TextTokensPrompt
(
TypedDict
):
prompt
:
str
...
...
@@ -61,17 +61,18 @@ class OpenAIServing:
def
__init__
(
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
):
super
().
__init__
()
self
.
engine
=
engine
self
.
async_engine_client
=
async_engine_client
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
...
...
@@ -102,6 +103,7 @@ class OpenAIServing:
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
self
.
request_logger
=
request_logger
self
.
return_tokens_as_token_ids
=
return_tokens_as_token_ids
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
...
...
@@ -150,6 +152,15 @@ class OpenAIServing:
})
return
json_str
async
def
_guided_decode_logits_processor
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
decoding_config
=
await
self
.
async_engine_client
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
return
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
)
async
def
_check_model
(
self
,
request
:
AnyRequest
,
...
...
@@ -254,9 +265,7 @@ class OpenAIServing:
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the messages, "
f
"Please reduce the length of the messages."
)
request
.
max_tokens
=
self
.
max_model_len
-
token_num
if
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
elif
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
raise
ValueError
(
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
...
...
@@ -384,11 +393,13 @@ class OpenAIServing:
)
@
staticmethod
def
_get_decoded_token
(
logprob
:
Logprob
,
token_id
:
int
,
tokenizer
:
AnyTokenizer
,
)
->
str
:
def
_get_decoded_token
(
logprob
:
Logprob
,
token_id
:
int
,
tokenizer
:
AnyTokenizer
,
return_as_token_id
:
bool
=
False
)
->
str
:
if
return_as_token_id
:
return
f
"token_id:
{
token_id
}
"
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
tokenizer
.
decode
(
token_id
)
vllm/entrypoints/openai/serving_tokenization.py
View file @
e661d594
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
load_chat_template
,
parse_chat_messages
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
DetokenizeResponse
,
ErrorResponse
,
...
...
@@ -17,14 +15,17 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
class
OpenAIServingTokenization
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
Async
LLM
Engine
,
async_engine_client
:
AsyncEngine
Client
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
...
...
@@ -32,7 +33,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
...
...
@@ -57,17 +58,17 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
,
mm_futures
=
parse_chat_messages
(
request
.
messages
,
model_config
,
tokenizer
)
for
message
in
request
.
messages
:
result
=
parse_chat_message_content
(
message
,
model_config
,
tokenizer
)
conversation
.
extend
(
result
.
messages
)
if
mm_futures
:
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
...
...
@@ -113,7 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
...
...
vllm/envs.py
View file @
e661d594
...
...
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_RPC_PORT
:
int
=
5570
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
...
...
@@ -28,7 +29,9 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_OPENVINO_KVCACHE_SPACE
:
int
=
0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
...
...
@@ -36,6 +39,7 @@ if TYPE_CHECKING:
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
:
bool
=
True
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_ASSETS_CACHE
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"assets"
)
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
...
@@ -43,10 +47,10 @@ if TYPE_CHECKING:
MAX_JOBS
:
Optional
[
str
]
=
None
NVCC_THREADS
:
Optional
[
str
]
=
None
VLLM_USE_PRECOMPILED
:
bool
=
False
VLLM_INSTALL_PUNICA_KERNELS
:
bool
=
False
VLLM_NO_DEPRECATION_WARNING
:
bool
=
False
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -92,10 +96,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_PRECOMPILED"
:
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_USE_PRECOMPILED"
)),
# If set, vllm will install Punica kernels
"VLLM_INSTALL_PUNICA_KERNELS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_INSTALL_PUNICA_KERNELS"
,
"0"
))),
# CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
...
...
@@ -142,6 +142,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'0'
))
if
'VLLM_PORT'
in
os
.
environ
else
None
,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT'
:
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'5570'
)),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE"
:
...
...
@@ -181,6 +186,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
@@ -246,11 +255,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
# CPU key-value cache space
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
# (CPU backend only) CPU key-value cache space.
# default is 4GB
"VLLM_CPU_KVCACHE_SPACE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_CPU_KVCACHE_SPACE"
,
"0"
)),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND"
:
lambda
:
os
.
getenv
(
"VLLM_CPU_OMP_THREADS_BIND"
,
"all"
),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE"
:
...
...
@@ -272,13 +290,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
0
)),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
"0"
)
)),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
"VLLM_USE_RAY_COMPILED_DAG"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
)),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
"0"
))),
# If the env var is set, it uses NCCL for communication in
# Ray's compiled DAG. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
,
"1"
))
),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
...
...
@@ -312,6 +337,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_NO_DEPRECATION_WARNING"
,
"0"
))),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
# To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
# end-env-vars-definition
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment