Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1340 additions
and
640 deletions
+1340
-640
vllm/engine/metrics.py
vllm/engine/metrics.py
+150
-30
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+5
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+190
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+19
-24
vllm/entrypoints/logger.py
vllm/entrypoints/logger.py
+41
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+53
-18
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+8
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+60
-84
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+19
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+119
-220
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+83
-100
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+54
-25
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+216
-72
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+135
-0
vllm/envs.py
vllm/envs.py
+56
-6
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+2
-0
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+5
-3
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-21
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+31
-15
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+92
-20
No files found.
vllm/engine/metrics.py
View file @
500b93c8
...
...
@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class
Metrics
:
labelname_finish_reason
=
"finished_reason"
_base_library
=
prometheus_client
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
# Unregister any existing vLLM collectors
self
.
_unregister_vllm_metrics
()
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
self
.
_create_info_cache_config
()
# System stats
# Scheduler State
self
.
gauge_scheduler_running
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_running
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_waiting
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_waiting
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_swapped
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_gpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
self
.
gauge_cpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_cpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:cpu_cache_usage_perc"
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
# Iteration stats
self
.
counter_num_preemption
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_num_preemption
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
)
self
.
counter_prompt_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_prompt_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
)
self
.
counter_generation_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_generation_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:generation_tokens_total"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
)
self
.
histogram_time_to_first_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_to_first_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_to_first_token_seconds"
,
documentation
=
"Histogram of time to first token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -86,7 +86,7 @@ class Metrics:
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
])
self
.
histogram_time_per_output_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_per_output_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -97,59 +97,157 @@ class Metrics:
# Request stats
# Latency
self
.
histogram_e2e_time_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_e2e_time_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of end to end request latency in seconds."
,
labelnames
=
labelnames
,
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
# Metadata
self
.
histogram_num_prompt_tokens_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_num_prompt_tokens_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_num_generation_tokens_request
=
\
self
.
_
base_library
.
H
istogram
(
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_best_of_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_n_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
counter_request_success
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_request_success
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculatie decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
)
self
.
gauge_spec_decode_efficiency
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_prompt_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_generation_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
)
def
_create_info_cache_config
(
self
)
->
None
:
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
def
_unregister_vllm_metrics
(
self
)
->
None
:
for
collector
in
list
(
self
.
_base_library
.
REGISTRY
.
_collector_to_names
):
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
self
.
_base_library
.
REGISTRY
.
unregister
(
collector
)
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
# end-metrics-definitions
class
_RayGaugeWrapper
:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_gauge
=
ray_metrics
.
Gauge
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_gauge
.
set_default_tags
(
labels
)
return
self
def
set
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_gauge
.
set
(
value
)
class
_RayCounterWrapper
:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_counter
=
ray_metrics
.
Counter
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_counter
.
set_default_tags
(
labels
)
return
self
def
inc
(
self
,
value
:
Union
[
int
,
float
]
=
1.0
):
if
value
==
0
:
return
return
self
.
_counter
.
inc
(
value
)
class
_RayHistogramWrapper
:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
buckets
)
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
return
self
def
observe
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_histogram
.
observe
(
value
)
class
RayMetrics
(
Metrics
):
...
...
@@ -157,7 +255,9 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library
=
ray_metrics
_gauge_cls
=
_RayGaugeWrapper
_counter_cls
=
_RayCounterWrapper
_histogram_cls
=
_RayHistogramWrapper
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
...
...
@@ -168,8 +268,9 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
# end-metrics-definitions
def
_create_info_cache_config
(
self
)
->
None
:
# No-op on purpose
pass
def
build_1_2_5_buckets
(
max_value
:
int
)
->
List
[
int
]:
...
...
@@ -325,8 +426,8 @@ class LoggingStatLogger(StatLoggerBase):
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft
tokens
tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted
tokens
tokens:
{
metrics
.
emitted_tokens
}
."
)
f
"Number of draft tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens:
{
metrics
.
emitted_tokens
}
."
)
class
PrometheusStatLogger
(
StatLoggerBase
):
...
...
@@ -454,7 +555,26 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
stats
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
stats
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
stats
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
stats
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
stats
.
spec_decode_metrics
.
emitted_tokens
)
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls
=
RayMetrics
\ No newline at end of file
_metrics_cls
=
RayMetrics
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
return
None
vllm/engine/output_processor/single_step.py
View file @
500b93c8
...
...
@@ -90,7 +90,11 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
for
parent_seq
in
parent_seqs
}
for
sample
in
samples
:
parent_child_dict
[
sample
.
parent_seq_id
].
append
(
sample
)
# Guard against a KeyError which can occur if the request was
# aborted while the output was generated
if
(
child_list
:
=
parent_child_dict
.
get
(
sample
.
parent_seq_id
))
is
not
None
:
child_list
.
append
(
sample
)
# List of (child, parent)
child_seqs
:
List
[
Tuple
[
Sequence
,
Sequence
]]
=
[]
...
...
vllm/entrypoints/chat_utils.py
0 → 100644
View file @
500b93c8
import
codecs
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
Awaitable
,
Iterable
,
List
,
Optional
,
Union
,
cast
,
final
# yapf conflicts with isort for this block
# yapf: disable
from
openai.types.chat
import
ChatCompletionContentPartImageParam
from
openai.types.chat
import
(
ChatCompletionContentPartParam
as
OpenAIChatCompletionContentPartParam
)
from
openai.types.chat
import
ChatCompletionContentPartTextParam
from
openai.types.chat
import
(
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
pydantic
import
ConfigDict
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Required
,
TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
logger
=
init_logger
(
__name__
)
class
CustomChatCompletionContentPartParam
(
TypedDict
,
total
=
False
):
__pydantic_config__
=
ConfigDict
(
extra
=
"allow"
)
# type: ignore
type
:
Required
[
str
]
"""The type of the content part."""
ChatCompletionContentPartParam
=
Union
[
OpenAIChatCompletionContentPartParam
,
CustomChatCompletionContentPartParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
List
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam
=
Union
[
OpenAIChatCompletionMessageParam
,
CustomChatCompletionMessageParam
]
@
final
# So that it should be compatible with Dict[str, str]
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
default_factory
=
list
)
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
if
chat_template
is
None
:
return
None
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
resolved_chat_template
=
f
.
read
()
except
OSError
as
e
:
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template
=
codecs
.
decode
(
chat_template
,
"unicode_escape"
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
resolved_chat_template
)
return
resolved_chat_template
@
lru_cache
(
maxsize
=
None
)
def
_image_token_str
(
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
model_config
.
hf_config
.
model_type
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
==
"chameleon"
:
return
"<image>"
raise
TypeError
(
"Unknown model type: {model_type}"
)
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def
_get_full_image_text_prompt
(
image_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return
f
"
{
image_token_str
}
\n
{
text_prompt
}
"
def
_parse_chat_message_content_parts
(
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
,
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple 'image_url' input is currently not supported."
)
image_url
=
cast
(
ChatCompletionContentPartImageParam
,
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported and "
"will be ignored."
)
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
mm_futures
:
image_token_str
=
_image_token_str
(
model_config
,
tokenizer
)
if
image_token_str
is
not
None
:
if
image_token_str
in
text_prompt
:
logger
.
warning
(
"Detected image token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
_get_full_image_text_prompt
(
image_token_str
=
image_token_str
,
text_prompt
=
text_prompt
,
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[],
mm_futures
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
return
_parse_chat_message_content_parts
(
role
,
content
,
model_config
,
tokenizer
)
vllm/entrypoints/llm.py
View file @
500b93c8
...
...
@@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.inputs
import
(
PromptInputs
,
PromptStrictInputs
,
TextPrompt
,
TextTokensPrompt
,
TokensPrompt
,
from
vllm.inputs
import
(
PromptInputs
,
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -69,6 +68,10 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
...
...
@@ -114,6 +117,7 @@ class LLM:
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
cpu_offload_gb
:
float
=
0
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
...
...
@@ -141,6 +145,7 @@ class LLM:
seed
=
seed
,
gpu_memory_utilization
=
gpu_memory_utilization
,
swap_space
=
swap_space
,
cpu_offload_gb
=
cpu_offload_gb
,
enforce_eager
=
enforce_eager
,
max_context_len_to_capture
=
max_context_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
...
...
@@ -232,7 +237,7 @@ class LLM:
@
overload
def
generate
(
self
,
inputs
:
Union
[
Prompt
Strict
Inputs
,
Sequence
[
Prompt
Strict
Inputs
]],
inputs
:
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
...
...
@@ -249,7 +254,7 @@ class LLM:
"instead."
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
Prompt
Strict
Inputs
,
Sequence
[
Prompt
Strict
Inputs
]],
prompts
:
Union
[
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -296,9 +301,7 @@ class LLM:
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
sampling_params
is
None
:
# Use default sampling params.
...
...
@@ -377,7 +380,7 @@ class LLM:
@
overload
def
encode
(
self
,
inputs
:
Union
[
Prompt
Strict
Inputs
,
Sequence
[
Prompt
Strict
Inputs
]],
inputs
:
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
...
...
@@ -394,7 +397,7 @@ class LLM:
"instead."
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
Prompt
Strict
Inputs
,
Sequence
[
Prompt
Strict
Inputs
]],
prompts
:
Union
[
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -411,7 +414,7 @@ class LLM:
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.Prompt
Strict
Inputs`
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
...
...
@@ -440,9 +443,7 @@ class LLM:
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
pooling_params
is
None
:
# Use default pooling params.
...
...
@@ -490,17 +491,11 @@ class LLM:
inputs
:
List
[
PromptInputs
]
=
[]
for
i
in
range
(
num_requests
):
if
prompts
is
not
None
:
if
prompt_token_ids
is
not
None
:
item
=
TextTokensPrompt
(
prompt
=
prompts
[
i
],
prompt_token_ids
=
prompt_token_ids
[
i
])
else
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
elif
prompt_token_ids
is
not
None
:
item
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
[
i
])
else
:
if
prompt_token_ids
is
not
None
:
item
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
[
i
])
else
:
raise
AssertionError
raise
AssertionError
inputs
.
append
(
item
)
...
...
@@ -508,7 +503,7 @@ class LLM:
def
_validate_and_add_requests
(
self
,
inputs
:
Union
[
Prompt
Strict
Inputs
,
Sequence
[
Prompt
Strict
Inputs
]],
inputs
:
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
...
...
vllm/entrypoints/logger.py
0 → 100644
View file @
500b93c8
from
typing
import
List
,
Optional
,
Union
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
logger
=
init_logger
(
__name__
)
class
RequestLogger
:
def
__init__
(
self
,
*
,
max_log_len
:
Optional
[
int
])
->
None
:
super
().
__init__
()
self
.
max_log_len
=
max_log_len
def
log_inputs
(
self
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
max_log_len
=
self
.
max_log_len
if
max_log_len
is
not
None
:
if
prompt
is
not
None
:
prompt
=
prompt
[:
max_log_len
]
if
prompt_token_ids
is
not
None
:
prompt_token_ids
=
prompt_token_ids
[:
max_log_len
]
logger
.
info
(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s, prompt_adapter_request: %s."
,
request_id
,
prompt
,
params
,
prompt_token_ids
,
lora_request
,
prompt_adapter_request
)
vllm/entrypoints/openai/api_server.py
View file @
500b93c8
...
...
@@ -18,6 +18,7 @@ from starlette.routing import Mount
import
vllm.envs
as
envs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -33,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -40,12 +43,12 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
logger
=
init_logger
(
__name__
)
engine
:
AsyncLLMEngine
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
logger
=
init_logger
(
'vllm.entrypoints.openai.api_server'
)
...
...
@@ -70,11 +73,13 @@ async def lifespan(app: fastapi.FastAPI):
router
=
APIRouter
()
# Add prometheus asgi middleware to route /metrics requests
route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
route
.
path_regex
=
re
.
compile
(
'^/metrics(?P<path>.*)$'
)
router
.
routes
.
append
(
route
)
def
mount_metrics
(
app
:
fastapi
.
FastAPI
):
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
'^/metrics(?P<path>.*)$'
)
app
.
routes
.
append
(
metrics_route
)
@
router
.
get
(
"/health"
)
...
...
@@ -86,7 +91,7 @@ async def health() -> Response:
@
router
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
generator
=
await
openai_serving_
comple
tion
.
create_tokenize
(
request
)
generator
=
await
openai_serving_
tokeniza
tion
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
...
...
@@ -97,7 +102,7 @@ async def tokenize(request: TokenizeRequest):
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
generator
=
await
openai_serving_
comple
tion
.
create_detokenize
(
request
)
generator
=
await
openai_serving_
tokeniza
tion
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
...
...
@@ -164,6 +169,8 @@ def build_app(args):
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
mount_metrics
(
app
)
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
args
.
allowed_origins
,
...
...
@@ -238,20 +245,48 @@ def run_server(args, llm_engine=None):
# When using single vLLM without engine_use_ray
model_config
=
asyncio
.
run
(
engine
.
get_model_config
())
if
args
.
disable_log_requests
:
request_logger
=
None
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
global
openai_serving_chat
global
openai_serving_completion
global
openai_serving_embedding
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
args
.
response_role
,
args
.
lora_modules
,
args
.
chat_template
)
global
openai_serving_tokenization
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
args
.
response_role
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
)
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
model_config
,
served_model_names
,
args
.
lora_modules
,
args
.
prompt_adapters
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
)
engine
,
model_config
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
,
request_logger
=
request_logger
,
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine
,
model_config
,
served_model_names
,
lora_modules
=
args
.
lora_modules
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
)
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
...
...
vllm/entrypoints/openai/cli_args.py
View file @
500b93c8
...
...
@@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using app.add_middleware(). "
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--max-log-len'
,
type
=
int
,
default
=
None
,
help
=
'Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'
\n\n
Default: Unlimited'
)
return
parser
...
...
vllm/entrypoints/openai/protocol.py
View file @
500b93c8
...
...
@@ -3,50 +3,16 @@
import
time
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
openai.types.chat
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
# pydantic needs the TypedDict from typing_extensions
from
typing_extensions
import
Annotated
,
Required
,
TypedDict
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
class
CustomChatCompletionContentPartParam
(
TypedDict
,
total
=
False
):
__pydantic_config__
=
ConfigDict
(
extra
=
"allow"
)
# type: ignore
type
:
Required
[
str
]
"""The type of the content part."""
ChatCompletionContentPartParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionContentPartParam
,
CustomChatCompletionContentPartParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
List
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionMessageParam
,
CustomChatCompletionMessageParam
]
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
model_config
=
ConfigDict
(
extra
=
"forbid"
)
...
...
@@ -155,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of
:
Optional
[
int
]
=
None
use_beam_search
:
Optional
[
bool
]
=
False
top_k
:
Optional
[
int
]
=
-
1
min_p
:
Optional
[
float
]
=
0.0
repetition_penalty
:
Optional
[
float
]
=
1.0
length_penalty
:
Optional
[
float
]
=
1.0
early_stopping
:
Optional
[
bool
]
=
False
ignore_eos
:
Optional
[
bool
]
=
False
min_tokens
:
Optional
[
int
]
=
0
use_beam_search
:
bool
=
False
top_k
:
int
=
-
1
min_p
:
float
=
0.0
repetition_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
early_stopping
:
bool
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
skip_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
include_stop_str_in_output
:
bool
=
False
ignore_eos
:
bool
=
False
min_tokens
:
int
=
0
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo
:
Optional
[
bool
]
=
Field
(
echo
:
bool
=
Field
(
default
=
False
,
description
=
(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."
),
)
add_generation_prompt
:
Optional
[
bool
]
=
Field
(
add_generation_prompt
:
bool
=
Field
(
default
=
True
,
description
=
(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens
:
Optional
[
bool
]
=
Field
(
add_special_tokens
:
bool
=
Field
(
default
=
False
,
description
=
(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to
F
alse (as is the "
"special tokens so this should be set to
f
alse (as is the "
"default)."
),
)
documents
:
Optional
[
List
[
Dict
[
str
,
str
]]]
=
Field
(
...
...
@@ -212,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
"Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
include_stop_str_in_output
:
Optional
[
bool
]
=
Field
(
default
=
False
,
description
=
(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."
),
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the JSON schema."
),
...
...
@@ -278,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
return
SamplingParams
(
n
=
self
.
n
,
best_of
=
self
.
best_of
,
presence_penalty
=
self
.
presence_penalty
,
frequency_penalty
=
self
.
frequency_penalty
,
repetition_penalty
=
self
.
repetition_penalty
,
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
min_p
=
self
.
min_p
,
seed
=
self
.
seed
,
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
max_tokens
=
self
.
max_tokens
,
min_tokens
=
self
.
min_tokens
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
best_of
=
self
.
best_of
,
top_k
=
self
.
top_k
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
skip_special_tokens
=
self
.
skip_special_tokens
,
...
...
@@ -301,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
)
@
model_validator
(
mode
=
'before'
)
...
...
@@ -382,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
user
:
Optional
[
str
]
=
None
# doc: begin-completion-sampling-params
use_beam_search
:
Optional
[
bool
]
=
False
top_k
:
Optional
[
int
]
=
-
1
min_p
:
Optional
[
float
]
=
0.0
repetition_penalty
:
Optional
[
float
]
=
1.0
length_penalty
:
Optional
[
float
]
=
1.0
early_stopping
:
Optional
[
bool
]
=
False
use_beam_search
:
bool
=
False
top_k
:
int
=
-
1
min_p
:
float
=
0.0
repetition_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
early_stopping
:
bool
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
ignore_eos
:
Optional
[
bool
]
=
False
min_tokens
:
Optional
[
int
]
=
0
skip_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
include_stop_str_in_output
:
bool
=
False
ignore_eos
:
bool
=
False
min_tokens
:
int
=
0
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output
:
Optional
[
bool
]
=
Field
(
default
=
Fals
e
,
add_special_tokens
:
bool
=
Field
(
default
=
Tru
e
,
description
=
(
"
Whether to includ
e the
stop string in the output.
"
"
This is only applied when the stop or stop_token_ids is se
t."
),
"
If tru
e
(
the
default), special tokens (e.g. BOS) will be added to
"
"
the promp
t."
),
)
response_format
:
Optional
[
ResponseFormat
]
=
Field
(
default
=
None
,
...
...
@@ -481,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
seed
=
self
.
seed
,
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
logprobs
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
logprobs
=
self
.
logprobs
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
prompt_logprobs
=
self
.
logprobs
if
self
.
echo
else
None
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
(
self
.
spaces_between_special_tokens
)
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
...
...
@@ -523,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
def
validate_stream_options
(
cls
,
data
):
if
data
.
get
(
"stream_options"
)
and
not
data
.
get
(
"stream"
):
raise
ValueError
(
"Stream options can only be defined when stream is
T
rue."
)
"Stream options can only be defined when stream is
t
rue."
)
return
data
class
EmbeddingRequest
(
BaseModel
):
class
EmbeddingRequest
(
OpenAI
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model
:
str
...
...
@@ -599,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
EmbeddingResponseData
(
BaseModel
):
class
EmbeddingResponseData
(
OpenAI
BaseModel
):
index
:
int
object
:
str
=
"embedding"
embedding
:
Union
[
List
[
float
],
str
]
class
EmbeddingResponse
(
BaseModel
):
class
EmbeddingResponse
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"list"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
...
@@ -704,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
# /v1/chat/completions is supported.
url
:
str
# The paramete
te
rs of the request.
body
:
Union
[
ChatCompletionRequest
,
]
# The parameters of the request.
body
:
ChatCompletionRequest
class
BatchResponseData
(
OpenAIBaseModel
):
...
...
@@ -737,16 +701,28 @@ class BatchRequestOutput(OpenAIBaseModel):
error
:
Optional
[
Any
]
class
TokenizeRequest
(
OpenAIBaseModel
):
class
Tokenize
Completion
Request
(
OpenAIBaseModel
):
model
:
str
prompt
:
str
add_special_tokens
:
bool
=
Field
(
default
=
True
)
class
TokenizeChatRequest
(
OpenAIBaseModel
):
model
:
str
messages
:
List
[
ChatCompletionMessageParam
]
add_generation_prompt
:
bool
=
Field
(
default
=
True
)
add_special_tokens
:
bool
=
Field
(
default
=
False
)
TokenizeRequest
=
Union
[
TokenizeCompletionRequest
,
TokenizeChatRequest
]
class
TokenizeResponse
(
OpenAIBaseModel
):
tokens
:
List
[
int
]
count
:
int
max_model_len
:
int
tokens
:
List
[
int
]
class
DetokenizeRequest
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/run_batch.py
View file @
500b93c8
...
...
@@ -6,6 +6,7 @@ import aiohttp
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
BatchRequestOutput
,
BatchResponseData
,
...
...
@@ -44,9 +45,17 @@ def parse_args():
type
=
nullable_str
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=
t
rue`."
)
"`request.add_generation_prompt=
T
rue`."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--max-log-len'
,
type
=
int
,
default
=
None
,
help
=
'Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'
\n\n
Default: Unlimited'
)
return
parser
.
parse_args
()
...
...
@@ -114,11 +123,20 @@ async def main(args):
# When using single vLLM without engine_use_ray
model_config
=
await
engine
.
get_model_config
()
if
args
.
disable_log_requests
:
request_logger
=
None
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
args
.
response_role
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
,
chat_template
=
None
,
)
# Submit all requests in the file to the engine "concurrently".
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
500b93c8
import
codecs
import
time
from
dataclasses
import
dataclass
,
field
from
functools
import
cached_property
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
Iterable
,
List
,
Optional
)
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypedDict
,
Union
,
cast
,
final
from
typing
import
Union
from
fastapi
import
Request
from
openai.types.chat
import
(
ChatCompletionContentPartImageParam
,
ChatCompletionContentPartTextParam
)
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionContentPartParam
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionMessageParam
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
FunctionCall
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
,
PromptAdapterPath
)
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
...
@@ -38,159 +37,31 @@ from vllm.utils import random_uuid
logger
=
init_logger
(
__name__
)
@
final
# So that it should be compatible with Dict[str, str]
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
default_factory
=
list
)
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
response_role
:
str
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
=
None
,
chat_template
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
response_role
:
str
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
self
.
response_role
=
response_role
self
.
_load_chat_template
(
chat_template
)
def
_load_chat_template
(
self
,
chat_template
:
Optional
[
str
]):
tokenizer
=
self
.
tokenizer
if
chat_template
is
not
None
:
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
tokenizer
.
chat_template
=
f
.
read
()
except
OSError
as
e
:
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer
.
chat_template
=
codecs
.
decode
(
chat_template
,
"unicode_escape"
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
tokenizer
.
chat_template
)
elif
tokenizer
.
chat_template
is
not
None
:
logger
.
info
(
"Using default chat template:
\n
%s"
,
tokenizer
.
chat_template
)
else
:
logger
.
warning
(
"No chat template provided. Chat API will not work."
)
@
cached_property
def
image_token_str
(
self
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
self
.
model_config
.
hf_config
.
model_type
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
self
.
tokenizer
.
decode
(
self
.
model_config
.
hf_config
.
image_token_index
)
else
:
raise
TypeError
(
"Unknown model type: {model_type}"
)
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def
_get_full_image_text_prompt
(
self
,
image_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return
f
"
{
image_token_str
}
\n
{
text_prompt
}
"
def
_parse_chat_message_content_parts
(
self
,
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple 'image_url' input is currently not supported."
)
image_url
=
cast
(
ChatCompletionContentPartImageParam
,
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported and "
"will be ignored."
)
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
mm_futures
:
image_token_str
=
self
.
image_token_str
if
image_token_str
is
not
None
:
if
image_token_str
in
text_prompt
:
logger
.
warning
(
"Detected image token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
self
.
_get_full_image_text_prompt
(
image_token_str
=
image_token_str
,
text_prompt
=
text_prompt
,
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
_parse_chat_message_content
(
self
,
message
:
ChatCompletionMessageParam
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[],
mm_futures
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
return
self
.
_parse_chat_message_content_parts
(
role
,
content
)
# If this is None we use the tokenizer's default chat template
self
.
chat_template
=
load_chat_template
(
chat_template
)
async
def
create_chat_completion
(
self
,
...
...
@@ -212,11 +83,20 @@ class OpenAIServingChat(OpenAIServing):
return
error_check_ret
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
model_config
=
self
.
model_config
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
conversation
:
List
[
ConversationMessage
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
msg
in
request
.
messages
:
chat_parsed_result
=
self
.
_parse_chat_message_content
(
msg
)
chat_parsed_result
=
parse_chat_message_content
(
msg
,
model_config
,
tokenizer
)
conversation
.
extend
(
chat_parsed_result
.
messages
)
mm_futures
.
extend
(
chat_parsed_result
.
mm_futures
)
...
...
@@ -225,13 +105,13 @@ class OpenAIServingChat(OpenAIServing):
tool
.
model_dump
()
for
tool
in
request
.
tools
]
prompt
=
self
.
tokenizer
.
apply_chat_template
(
prompt
=
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
tokenize
=
False
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
chat_template
=
request
.
chat_template
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
**
(
request
.
chat_template_kwargs
or
{}),
)
except
Exception
as
e
:
...
...
@@ -250,61 +130,71 @@ class OpenAIServingChat(OpenAIServing):
logger
.
error
(
"Error in loading multi-modal data: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
request_id
=
f
"c
mpl
-
{
random_uuid
()
}
"
request_id
=
f
"c
hat
-
{
random_uuid
()
}
"
try
:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids
,
prompt_text
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
sampling_params
=
request
.
to_sampling_params
()
_
,
lora_request
=
self
.
_maybe_get_adapter
(
request
)
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_
tokenizer
()
))
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logits_processor
)
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
engine_inputs
:
PromptInputs
=
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
],
}
if
mm_data
is
not
None
:
engine_inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
if
(
not
is_tracing_enabled
and
raw_request
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
result_generator
=
self
.
engine
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
inputs
:
PromptInputs
=
{
"prompt"
:
prompt_text
,
"prompt_token_ids"
:
prompt_ids
,
}
if
mm_data
:
inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
if
not
is_tracing_enabled
and
raw_request
and
contains_trace_headers
(
raw_request
.
headers
):
log_tracing_disabled_warning
()
result_generator
=
self
.
engine
.
generate
(
inputs
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
=
trace_headers
,
)
# Streaming response
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
else
:
try
:
return
await
self
.
chat_completion_full_generator
(
request
,
raw_request
,
result_generator
,
request_id
,
conversation
)
conversation
,
tokenizer
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -316,9 +206,12 @@ class OpenAIServingChat(OpenAIServing):
return
request
.
messages
[
-
1
][
"role"
]
async
def
chat_completion_stream_generator
(
self
,
request
:
ChatCompletionRequest
,
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
conversation
:
List
[
ConversationMessage
]
self
,
request
:
ChatCompletionRequest
,
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
tokenizer
:
PreTrainedTokenizer
,
)
->
AsyncGenerator
[
str
,
None
]:
model_name
=
self
.
served_model_names
[
0
]
created_time
=
int
(
time
.
time
())
...
...
@@ -326,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
first_iteration
=
True
# Send response for each token for each request.n (index)
assert
request
.
n
is
not
None
previous_texts
=
[
""
]
*
request
.
n
previous_num_tokens
=
[
0
]
*
request
.
n
finish_reason_sent
=
[
False
]
*
request
.
n
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
try
:
async
for
res
in
result_generator
:
# We need to do it here, because if there are exceptions in
...
...
@@ -339,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role
=
self
.
get_chat_request_role
(
request
)
for
i
in
range
(
request
.
n
):
for
i
in
range
(
num_choices
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
...
...
@@ -367,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content
=
conversation
[
-
1
][
"content"
]
if
last_msg_content
:
for
i
in
range
(
request
.
n
):
for
i
in
range
(
num_choices
):
choice_data
=
(
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
last_msg_content
),
logprobs
=
None
,
finish_reason
=
None
))
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
logprobs
=
None
,
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
...
...
@@ -405,6 +299,7 @@ class OpenAIServingChat(OpenAIServing):
logprobs
=
self
.
_create_chat_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
out_logprobs
,
tokenizer
=
tokenizer
,
num_output_top_logprobs
=
request
.
top_logprobs
,
)
else
:
...
...
@@ -493,9 +388,13 @@ class OpenAIServingChat(OpenAIServing):
yield
"data: [DONE]
\n\n
"
async
def
chat_completion_full_generator
(
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Optional
[
Request
],
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
conversation
:
List
[
ConversationMessage
]
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Optional
[
Request
],
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
tokenizer
:
PreTrainedTokenizer
,
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
model_name
=
self
.
served_model_names
[
0
]
...
...
@@ -523,6 +422,7 @@ class OpenAIServingChat(OpenAIServing):
token_ids
=
token_ids
,
top_logprobs
=
out_logprobs
,
num_output_top_logprobs
=
request
.
top_logprobs
,
tokenizer
=
tokenizer
,
)
else
:
logprobs
=
None
...
...
@@ -577,16 +477,14 @@ class OpenAIServingChat(OpenAIServing):
return
response
def
_get_top_logprobs
(
self
,
logprobs
:
Dict
[
int
,
Logprob
],
to
p_logprobs
:
Optional
[
int
]
)
->
List
[
ChatCompletionLogProb
]:
self
,
logprobs
:
Dict
[
int
,
Logprob
],
top_logprobs
:
Optional
[
int
],
to
kenizer
:
PreTrainedTokenizer
)
->
List
[
ChatCompletionLogProb
]:
return
[
ChatCompletionLogProb
(
token
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
]),
token
=
(
token
:
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
],
tokenizer
)),
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
bytes
=
list
(
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
]).
encode
(
"utf-8"
,
errors
=
"replace"
)))
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
)))
for
i
,
p
in
enumerate
(
logprobs
.
items
())
if
top_logprobs
and
i
<
top_logprobs
]
...
...
@@ -595,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
self
,
token_ids
:
GenericSequence
[
int
],
top_logprobs
:
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]],
tokenizer
:
PreTrainedTokenizer
,
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
)
->
ChatCompletionLogProbs
:
"""Create OpenAI-style logprobs."""
...
...
@@ -604,12 +503,11 @@ class OpenAIServingChat(OpenAIServing):
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
tokenizer
.
decode
(
token_id
)
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
token
=
self
.
tokenizer
.
decode
(
token_id
),
bytes
=
list
(
self
.
tokenizer
.
decode
(
token_id
).
encode
(
"utf-8"
,
errors
=
"replace"
))))
token
=
token
,
bytes
=
list
(
token
.
encode
(
"utf-8"
,
errors
=
"replace"
))))
else
:
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
...
...
@@ -620,6 +518,7 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs
[
token_id
].
decoded_token
.
encode
(
"utf-8"
,
errors
=
"replace"
)),
top_logprobs
=
self
.
_get_top_logprobs
(
step_top_logprobs
,
num_output_top_logprobs
)))
step_top_logprobs
,
num_output_top_logprobs
,
tokenizer
)))
return
ChatCompletionLogProbs
(
content
=
logprobs_content
)
vllm/entrypoints/openai/serving_completion.py
View file @
500b93c8
...
...
@@ -2,12 +2,14 @@ import time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
typing
import
Tuple
,
cast
from
fastapi
import
Request
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
CompletionLogProbs
,
...
...
@@ -16,10 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
DetokenizeRequest
,
DetokenizeResponse
,
TokenizeRequest
,
TokenizeResponse
,
UsageInfo
)
UsageInfo
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
,
...
...
@@ -41,40 +40,24 @@ TypeCreateLogProbsFn = Callable[
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
CompletionLogProbs
]
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens
=
False
prompts
=
[
prompt
]
# case 1: a string
if
isinstance
(
prompt
,
list
):
if
len
(
prompt
)
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
elif
isinstance
(
prompt
[
0
],
str
):
prompt_is_tokens
=
False
prompts
=
prompt
# case 2: array of strings
elif
isinstance
(
prompt
[
0
],
int
):
prompt_is_tokens
=
True
prompts
=
[
prompt
]
# case 3: array of tokens
elif
isinstance
(
prompt
[
0
],
list
)
and
isinstance
(
prompt
[
0
][
0
],
int
):
prompt_is_tokens
=
True
prompts
=
prompt
# case 4: array of token arrays
else
:
raise
ValueError
(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
return
prompt_is_tokens
,
prompts
class
OpenAIServingCompletion
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]]):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
)
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
)
async
def
create_completion
(
self
,
request
:
CompletionRequest
,
raw_request
:
Request
):
...
...
@@ -103,41 +86,45 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators
:
List
[
AsyncIterator
[
RequestOutput
]]
=
[]
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
sampling_params
=
request
.
to_sampling_params
()
adapter_type
,
adapter_request
=
self
.
_maybe_get_adapter
(
request
)
lora_request
,
prompt_adapter_request
=
None
,
None
if
adapter_type
==
'LoRA'
:
lora_request
,
prompt_adapter_request
=
adapter_request
,
None
elif
adapter_type
==
'PromptAdapter'
:
lora_request
,
prompt_adapter_request
=
None
,
adapter_request
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_
tokenizer
()
))
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logit_processor
)
prompt_is_tokens
,
prompts
=
parse_prompt_format
(
request
.
prompt
)
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
else
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
prompt_ids
,
prompt_text
=
prompt_formats
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
tokenizer
,
request
.
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
prompt_inputs
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
is_tracing_enabled
=
await
self
.
engine
.
is_tracing_enabled
()
trace_headers
=
None
...
...
@@ -148,12 +135,9 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning
()
generator
=
self
.
engine
.
generate
(
{
"prompt"
:
prompt_text
,
"prompt_token_ids"
:
prompt_ids
},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
request_id
_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
...
...
@@ -182,7 +166,8 @@ class OpenAIServingCompletion(OpenAIServing):
request_id
,
created_time
,
model_name
,
num_prompts
=
len
(
prompts
))
num_prompts
=
len
(
prompts
),
tokenizer
=
tokenizer
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
...
...
@@ -193,8 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
for
i
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if
final_res
.
prompt
is
None
:
final_res
.
prompt
=
prompts
[
i
][
"prompt"
]
final_res_batch_checked
=
cast
(
List
[
RequestOutput
],
final_res_batch
)
response
=
self
.
request_output_to_completion_response
(
final_res_batch
,
request
,
request_id
,
created_time
,
model_name
)
final_res_batch_checked
,
request
,
request_id
,
created_time
,
model_name
,
tokenizer
,
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -221,11 +225,12 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
:
int
,
model_name
:
str
,
num_prompts
:
int
,
tokenizer
:
PreTrainedTokenizer
,
)
->
AsyncGenerator
[
str
,
None
]:
assert
request
.
n
is
not
None
previous_texts
=
[
""
]
*
request
.
n
*
num_prompts
previous_num_tokens
=
[
0
]
*
request
.
n
*
num_prompts
has_echoed
=
[
False
]
*
request
.
n
*
num_prompts
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
*
num_prompts
previous_num_tokens
=
[
0
]
*
num_choices
*
num_prompts
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
try
:
async
for
prompt_idx
,
res
in
result_generator
:
...
...
@@ -236,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
request
.
n
i
=
output
.
index
+
prompt_idx
*
num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
...
...
@@ -271,6 +276,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids
=
delta_token_ids
,
top_logprobs
=
out_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
tokenizer
=
tokenizer
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
...
...
@@ -339,12 +345,13 @@ class OpenAIServingCompletion(OpenAIServing):
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
tokenizer
:
PreTrainedTokenizer
,
)
->
CompletionResponse
:
choices
:
List
[
CompletionResponseChoice
]
=
[]
num_prompt_tokens
=
0
num_generated_tokens
=
0
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
...
...
@@ -370,6 +377,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs
=
self
.
_create_completion_logprobs
(
token_ids
=
token_ids
,
top_logprobs
=
out_logprobs
,
tokenizer
=
tokenizer
,
num_output_top_logprobs
=
request
.
logprobs
,
)
else
:
...
...
@@ -407,6 +415,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids
:
GenericSequence
[
int
],
top_logprobs
:
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]],
num_output_top_logprobs
:
int
,
tokenizer
:
PreTrainedTokenizer
,
initial_text_offset
:
int
=
0
,
)
->
CompletionLogProbs
:
"""Create logprobs for OpenAI Completion API."""
...
...
@@ -420,13 +429,13 @@ class OpenAIServingCompletion(OpenAIServing):
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
self
.
tokenizer
.
decode
(
token_id
)
token
=
tokenizer
.
decode
(
token_id
)
out_tokens
.
append
(
token
)
out_token_logprobs
.
append
(
None
)
out_top_logprobs
.
append
(
None
)
else
:
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
)
token_id
,
tokenizer
)
token_logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
)
out_tokens
.
append
(
token
)
...
...
@@ -439,7 +448,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs
.
append
({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
]):
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
]
,
tokenizer
):
max
(
top_lp
[
1
].
logprob
,
-
9999.0
)
for
i
,
top_lp
in
enumerate
(
step_top_logprobs
.
items
())
if
num_output_top_logprobs
>=
i
...
...
@@ -457,29 +466,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens
=
out_tokens
,
top_logprobs
=
out_top_logprobs
,
)
async
def
create_tokenize
(
self
,
request
:
TokenizeRequest
)
->
TokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
request
.
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
)
->
DetokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
request
.
tokens
)
return
DetokenizeResponse
(
prompt
=
input_text
)
vllm/entrypoints/openai/serving_embedding.py
View file @
500b93c8
import
base64
import
time
from
typing
import
AsyncIterator
,
List
,
Optional
,
Tuple
from
typing
import
AsyncIterator
,
List
,
Optional
,
Tuple
,
cast
import
numpy
as
np
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_completion
import
parse_prompt_format
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingRequestOutput
...
...
@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
data
:
List
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding
=
final_res
.
outputs
.
embedding
if
encoding_format
==
"base64"
:
embedding
=
base64
.
b64encode
(
np
.
array
(
embedding
))
embedding_bytes
=
np
.
array
(
embedding
).
tobytes
()
embedding
=
base64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
embedding
)
data
.
append
(
embedding_data
)
...
...
@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
class
OpenAIServingEmbedding
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
request_logger
:
Optional
[
RequestLogger
],
):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
)
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
...
...
@@ -80,32 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported"
)
model_name
=
request
.
model
request_id
=
f
"
cmpl
-
{
random_uuid
()
}
"
request_id
=
f
"
embd
-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
# Schedule the request and get the result generator.
generators
=
[]
generators
:
List
[
AsyncIterator
[
EmbeddingRequestOutput
]]
=
[]
try
:
prompt_is_tokens
,
prompts
=
parse_prompt_format
(
request
.
input
)
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
pooling_params
=
request
.
to_pooling_params
()
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
)
else
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
tokenizer
,
request
.
input
,
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
prompt_inputs
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_ids
,
prompt_text
=
prompt_formats
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for embedding models"
)
generator
=
self
.
engine
.
encode
(
{
"prompt"
:
prompt_text
,
"prompt_token_ids"
:
prompt_ids
},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
pooling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
request_id_item
,
lora_request
=
lora_request
,
)
generators
.
append
(
generator
)
...
...
@@ -124,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
final_res_batch_checked
=
cast
(
List
[
EmbeddingRequestOutput
],
final_res_batch
)
response
=
request_output_to_embedding_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
,
final_res_batch
_checked
,
request_id
,
created_time
,
model_name
,
encoding_format
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
500b93c8
import
json
import
pathlib
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
pydantic
import
Field
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
ErrorResponse
,
ModelCard
,
ModelList
,
ModelPermission
,
TokenizeRequest
)
ModelPermission
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
TokenizeRequest
)
# yapf: enable
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -32,7 +43,18 @@ class PromptAdapterPath:
@
dataclass
class
LoRAModulePath
:
name
:
str
local_path
:
str
path
:
str
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
TextTokensPrompt
(
TypedDict
):
prompt
:
str
prompt_token_ids
:
List
[
int
]
class
OpenAIServing
:
...
...
@@ -42,8 +64,10 @@ class OpenAIServing:
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]]
=
None
,
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
):
super
().
__init__
()
...
...
@@ -51,14 +75,6 @@ class OpenAIServing:
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
# A separate tokenizer to map token IDs to strings.
self
.
tokenizer
=
get_tokenizer
(
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
tokenizer_revision
=
model_config
.
tokenizer_revision
,
trust_remote_code
=
model_config
.
trust_remote_code
,
truncation_side
=
"left"
)
self
.
served_model_names
=
served_model_names
self
.
lora_requests
=
[]
...
...
@@ -67,15 +83,15 @@ class OpenAIServing:
LoRARequest
(
lora_name
=
lora
.
name
,
lora_int_id
=
i
,
lora_
local_
path
=
lora
.
local_
path
,
lora_path
=
lora
.
path
,
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
]
self
.
prompt_adapter_requests
=
[]
if
prompt_adapters
is
not
None
:
for
i
,
prompt_adapter
in
enumerate
(
prompt_adapters
,
start
=
1
):
with
open
(
f
"./
{
prompt_adapter
.
local_path
}
"
f
"/
adapter_config.json"
)
as
f
:
with
pathlib
.
Path
(
prompt_adapter
.
local_path
,
"
adapter_config.json"
)
.
open
()
as
f
:
adapter_config
=
json
.
load
(
f
)
num_virtual_tokens
=
adapter_config
[
"num_virtual_tokens"
]
self
.
prompt_adapter_requests
.
append
(
...
...
@@ -85,6 +101,8 @@ class OpenAIServing:
prompt_adapter_local_path
=
prompt_adapter
.
local_path
,
prompt_adapter_num_virtual_tokens
=
num_virtual_tokens
))
self
.
request_logger
=
request_logger
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
model_cards
=
[
...
...
@@ -133,9 +151,8 @@ class OpenAIServing:
return
json_str
async
def
_check_model
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
self
,
request
:
AnyRequest
,
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
return
None
...
...
@@ -151,62 +168,65 @@ class OpenAIServing:
err_type
=
"NotFoundError"
,
status_code
=
HTTPStatus
.
NOT_FOUND
)
def
_maybe_get_adapter
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
EmbeddingRequest
]
)
->
Tuple
[
Optional
[
str
],
Optional
[
Union
[
LoRARequest
,
PromptAdapterRequest
]]]:
def
_maybe_get_adapters
(
self
,
request
:
AnyRequest
)
->
Union
[
Tuple
[
None
,
None
],
Tuple
[
LoRARequest
,
None
],
Tuple
[
None
,
PromptAdapterRequest
]]:
if
request
.
model
in
self
.
served_model_names
:
return
None
,
None
for
lora
in
self
.
lora_requests
:
if
request
.
model
==
lora
.
lora_name
:
return
'LoRA'
,
lora
return
lora
,
None
for
prompt_adapter
in
self
.
prompt_adapter_requests
:
if
request
.
model
==
prompt_adapter
.
prompt_adapter_name
:
return
'PromptAdapter'
,
prompt_adapter
return
None
,
prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
def
_validate_prompt_and_tokenize
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
],
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
True
)
->
Tuple
[
List
[
int
],
str
]:
if
not
(
prompt
or
prompt_ids
):
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
if
(
prompt
and
prompt_ids
):
raise
ValueError
(
"Only one of prompt or prompt_ids should be provided."
)
if
prompt_ids
is
None
:
# When using OpenAIServingChat for chat completions, for
# most models the special tokens (e.g., BOS) have already
# been added by the chat template. Therefore, we do not
# need to add them again.
# Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs
:
Dict
[
str
,
Any
]
=
{
"add_special_tokens"
:
add_special_tokens
}
if
truncate_prompt_tokens
is
not
None
:
tokenizer_kwargs
.
update
({
"truncation"
:
True
,
"max_length"
:
truncate_prompt_tokens
,
})
input_ids
=
self
.
tokenizer
(
prompt
,
**
tokenizer_kwargs
).
input_ids
elif
truncate_prompt_tokens
is
not
None
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
def
_normalize_prompt_text_to_input
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt
:
str
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]],
add_special_tokens
:
bool
,
)
->
TextTokensPrompt
:
if
truncate_prompt_tokens
is
None
:
encoded
=
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
)
else
:
encoded
=
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
input_ids
=
encoded
.
input_ids
input_text
=
prompt
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
def
_normalize_prompt_tokens_to_input
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt_ids
:
List
[
int
],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]],
)
->
TextTokensPrompt
:
if
truncate_prompt_tokens
is
None
:
input_ids
=
prompt_ids
else
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_text
=
tokenizer
.
decode
(
input_ids
)
input_text
=
prompt
if
prompt
is
not
None
else
self
.
tokenizer
.
decode
(
prompt_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
def
_validate_input
(
self
,
request
:
AnyRequest
,
input_ids
:
List
[
int
],
input_text
:
str
,
)
->
TextTokensPrompt
:
token_num
=
len
(
input_ids
)
# Note: EmbeddingRequest doesn't have max_tokens
...
...
@@ -216,13 +236,16 @@ class OpenAIServing:
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the input for embedding "
f
"generation. Please reduce the length of the input."
,
)
return
input_ids
,
input_text
f
"generation. Please reduce the length of the input."
)
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if
isinstance
(
request
,
(
TokenizeRequest
,
DetokenizeRequest
)):
return
input_ids
,
input_text
if
isinstance
(
request
,
(
TokenizeCompletionRequest
,
TokenizeChatRequest
,
DetokenizeRequest
)):
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
if
request
.
max_tokens
is
None
:
if
token_num
>=
self
.
max_model_len
:
...
...
@@ -230,7 +253,7 @@ class OpenAIServing:
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the messages, "
f
"Please reduce the length of the messages."
,
)
f
"Please reduce the length of the messages."
)
request
.
max_tokens
=
self
.
max_model_len
-
token_num
if
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
...
...
@@ -240,11 +263,132 @@ class OpenAIServing:
f
"
{
request
.
max_tokens
+
token_num
}
tokens "
f
"(
{
token_num
}
in the messages, "
f
"
{
request
.
max_tokens
}
in the completion). "
f
"Please reduce the length of the messages or completion."
,
)
f
"Please reduce the length of the messages or completion."
)
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
def
_tokenize_prompt_input
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt_input
:
Union
[
str
,
List
[
int
]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
,
)
->
TextTokensPrompt
:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return
next
(
self
.
_tokenize_prompt_inputs
(
request
,
tokenizer
,
[
prompt_input
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
,
))
def
_tokenize_prompt_inputs
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt_inputs
:
Iterable
[
Union
[
str
,
List
[
int
]]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
,
)
->
Iterator
[
TextTokensPrompt
]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for
text
in
prompt_inputs
:
if
isinstance
(
text
,
str
):
yield
self
.
_normalize_prompt_text_to_input
(
request
,
tokenizer
,
prompt
=
text
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
,
)
else
:
yield
self
.
_normalize_prompt_tokens_to_input
(
request
,
tokenizer
,
prompt_ids
=
text
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
)
def
_tokenize_prompt_input_or_inputs
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
input_or_inputs
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
,
)
->
Iterator
[
TextTokensPrompt
]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for
prompt_input
in
parse_and_batch_prompt
(
input_or_inputs
):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if
prompt_input
[
"is_tokens"
]
is
False
:
yield
self
.
_normalize_prompt_text_to_input
(
request
,
tokenizer
,
prompt
=
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
,
)
else
:
yield
self
.
_normalize_prompt_tokens_to_input
(
request
,
tokenizer
,
prompt_ids
=
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
)
def
_log_inputs
(
self
,
request_id
:
str
,
inputs
:
Union
[
str
,
List
[
int
],
TextTokensPrompt
],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
if
self
.
request_logger
is
None
:
return
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
None
elif
isinstance
(
inputs
,
list
):
prompt
=
None
prompt_token_ids
=
inputs
else
:
return
input_ids
,
input_text
prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
self
.
request_logger
.
log_inputs
(
request_id
,
prompt
,
prompt_token_ids
,
params
=
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
def
_get_decoded_token
(
self
,
logprob
:
Logprob
,
token_id
:
int
)
->
str
:
@
staticmethod
def
_get_decoded_token
(
logprob
:
Logprob
,
token_id
:
int
,
tokenizer
:
AnyTokenizer
,
)
->
str
:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
self
.
tokenizer
.
decode
(
token_id
)
return
tokenizer
.
decode
(
token_id
)
vllm/entrypoints/openai/serving_tokenization.py
0 → 100644
View file @
500b93c8
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
DetokenizeResponse
,
ErrorResponse
,
TokenizeChatRequest
,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.utils
import
random_uuid
class
OpenAIServingTokenization
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
# If this is None we use the tokenizer's default chat template
self
.
chat_template
=
load_chat_template
(
chat_template
)
async
def
create_tokenize
(
self
,
request
:
TokenizeRequest
,
)
->
Union
[
TokenizeResponse
,
ErrorResponse
]:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
request_id
=
f
"tokn-
{
random_uuid
()
}
"
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
conversation
:
List
[
ConversationMessage
]
=
[]
for
message
in
request
.
messages
:
result
=
parse_chat_message_content
(
message
,
model_config
,
tokenizer
)
conversation
.
extend
(
result
.
messages
)
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
conversation
=
conversation
,
tokenize
=
False
,
chat_template
=
self
.
chat_template
)
assert
isinstance
(
prompt
,
str
)
else
:
prompt
=
request
.
prompt
self
.
_log_inputs
(
request_id
,
prompt
,
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
add_special_tokens
=
request
.
add_special_tokens
,
)
input_ids
=
prompt_input
[
"prompt_token_ids"
]
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
,
)
->
Union
[
DetokenizeResponse
,
ErrorResponse
]:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
request_id
=
f
"tokn-
{
random_uuid
()
}
"
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for tokenization"
)
prompt_input
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
request
.
tokens
,
)
input_text
=
prompt_input
[
"prompt"
]
return
DetokenizeResponse
(
prompt
=
input_text
)
vllm/envs.py
View file @
500b93c8
...
...
@@ -17,7 +17,8 @@ if TYPE_CHECKING:
S3_ACCESS_KEY_ID
:
Optional
[
str
]
=
None
S3_SECRET_ACCESS_KEY
:
Optional
[
str
]
=
None
S3_ENDPOINT_URL
:
Optional
[
str
]
=
None
VLLM_CONFIG_ROOT
:
str
=
""
VLLM_CACHE_ROOT
:
str
=
os
.
path
.
expanduser
(
"~/.cache/vllm"
)
VLLM_CONFIG_ROOT
:
str
=
os
.
path
.
expanduser
(
"~/.config/vllm"
)
VLLM_USAGE_STATS_SERVER
:
str
=
"https://stats.vllm.ai"
VLLM_NO_USAGE_STATS
:
bool
=
False
VLLM_DO_NOT_TRACK
:
bool
=
False
...
...
@@ -31,10 +32,12 @@ if TYPE_CHECKING:
VLLM_OPENVINO_KVCACHE_SPACE
:
int
=
0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/
xla_cache
/
"
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"
xla_cache"
)
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_ASSETS_CACHE
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"assets"
)
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_TARGET_DEVICE
:
str
=
"cuda"
MAX_JOBS
:
Optional
[
str
]
=
None
...
...
@@ -45,6 +48,21 @@ if TYPE_CHECKING:
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
VERBOSE
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
os
.
path
.
join
(
os
.
path
.
expanduser
(
"~"
),
".cache"
),
)
def
get_default_config_root
():
return
os
.
getenv
(
"XDG_CONFIG_HOME"
,
os
.
path
.
join
(
os
.
path
.
expanduser
(
"~"
),
".config"
),
)
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
...
...
@@ -89,15 +107,28 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
getenv
(
'VERBOSE'
,
'0'
))),
# Root directory for VLLM configuration files
# Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set
# Note that this not only affects how vllm finds its configuration files
# during runtime, but also affects how vllm installs its configuration
# files during **installation**.
"VLLM_CONFIG_ROOT"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CONFIG_ROOT"
,
None
)
or
os
.
getenv
(
"XDG_CONFIG_HOME"
,
None
)
or
os
.
path
.
expanduser
(
"~/.config"
),
lambda
:
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_CONFIG_ROOT"
,
os
.
path
.
join
(
get_default_config_root
(),
"vllm"
),
)),
# ================== Runtime Env Vars ==================
# Root directory for VLLM cache files
# Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set
"VLLM_CACHE_ROOT"
:
lambda
:
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_CACHE_ROOT"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
),
)),
# used in distributed environment to determine the master address
'VLLM_HOST_IP'
:
lambda
:
os
.
getenv
(
'VLLM_HOST_IP'
,
""
)
or
os
.
getenv
(
"HOST_IP"
,
""
),
...
...
@@ -231,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"
,
False
)),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER"
:
lambda
:
bool
(
os
.
getenv
(
"VLLM_USE_RAY_SPMD_WORKER"
,
0
)),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
...
...
@@ -242,6 +280,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD"
:
lambda
:
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"fork"
),
# Path to the cache for storing downloaded assets
"VLLM_ASSETS_CACHE"
:
lambda
:
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_ASSETS_CACHE"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"assets"
),
)),
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
...
...
@@ -250,7 +296,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
lambda
:
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_ASSETS_CACHE"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
)),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"65536"
)),
...
...
@@ -262,7 +312,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# end-env-vars-definition
def
__getattr__
(
name
):
def
__getattr__
(
name
:
str
):
# lazy evaluation of environment variables
if
name
in
environment_variables
:
return
environment_variables
[
name
]()
...
...
vllm/executor/cpu_executor.py
View file @
500b93c8
...
...
@@ -17,6 +17,8 @@ logger = init_logger(__name__)
class
CPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"cpu"
assert
self
.
lora_config
is
None
,
"cpu backend doesn't support LoRA"
...
...
vllm/executor/distributed_gpu_executor.py
View file @
500b93c8
...
...
@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks
=
num_cpu_blocks
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Optional
[
List
[
SamplerOutput
]
]
:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
"start_worker_execution_loop"
,
...
...
@@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor):
**
self
.
extra_execute_model_run_workers_kwargs
)
# Only the driver worker returns the sampling results.
return
self
.
_driver_execute_model
(
execute_model_req
)
driver_outputs
=
self
.
_driver_execute_model
(
execute_model_req
)
assert
driver_outputs
is
not
None
return
driver_outputs
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
if
self
.
parallel_worker_tasks
is
None
:
...
...
vllm/executor/executor_base.py
View file @
500b93c8
import
asyncio
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Set
,
Tuple
...
...
@@ -19,6 +18,8 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices.
"""
uses_ray
:
bool
# whether the executor uses Ray for orchestration.
def
__init__
(
self
,
model_config
:
ModelConfig
,
...
...
@@ -132,26 +133,6 @@ class ExecutorBase(ABC):
class
ExecutorAsyncBase
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
)
->
None
:
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
super
().
__init__
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
load_config
,
lora_config
,
multimodal_config
,
speculative_config
,
prompt_adapter_config
)
@
abstractmethod
async
def
execute_model_async
(
self
,
...
...
vllm/executor/gpu_executor.py
View file @
500b93c8
...
...
@@ -12,8 +12,19 @@ from vllm.worker.worker_base import WorkerWrapperBase
logger
=
init_logger
(
__name__
)
def
create_worker
(
worker_module_name
,
worker_class_name
,
**
kwargs
):
wrapper
=
WorkerWrapperBase
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
)
wrapper
.
init_worker
(
**
kwargs
)
return
wrapper
.
worker
class
GPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model.
"""
...
...
@@ -51,25 +62,30 @@ class GPUExecutor(ExecutorBase):
or
(
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
),
)
def
_get_create_worker_kwargs
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
)
->
Dict
:
worker_kwargs
=
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
)
if
self
.
speculative_config
is
None
:
worker_kwargs
.
update
(
worker_module_name
=
"vllm.worker.worker"
,
worker_class_name
=
"Worker"
)
else
:
worker_kwargs
.
update
(
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
,
worker_class_name
=
"create_spec_worker"
)
return
worker_kwargs
def
_create_worker
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
):
if
self
.
speculative_config
is
None
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
else
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
wrapper
=
WorkerWrapperBase
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
)
wrapper
.
init_worker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
wrapper
.
worker
return
create_worker
(
**
self
.
_get_create_worker_kwargs
(
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
))
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks by invoking the
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
500b93c8
import
asyncio
import
os
import
signal
import
weakref
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.gpu_executor
import
create_worker
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
from
vllm.utils
import
(
cuda_device_count_stateless
,
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
error_on_invalid_device_count_status
,
get_distributed_init_method
,
get_open_port
,
get_vllm_instance_id
,
make_async
,
...
...
@@ -22,9 +25,12 @@ logger = init_logger(__name__)
class
MultiprocessingGPUExecutor
(
DistributedGPUExecutor
):
"""Python multiprocessing-based multi-GPU executor"""
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
tensor_parallel_size
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
...
...
@@ -47,8 +53,15 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if
world_size
>
1
:
maybe_set_triton_cache_manager
()
assert
world_size
<=
cuda_device_count_stateless
(),
(
"please set tensor_parallel_size to less than max local gpu count"
)
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
error_on_invalid_device_count_status
()
...
...
@@ -58,26 +71,53 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
distributed_init_method
=
get_distributed_init_method
(
"127.0.0.1"
,
get_open_port
())
self
.
workers
:
List
[
ProcessWorkerWrapper
]
=
[]
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self
.
tp_driver_workers
:
List
[
ProcessWorkerWrapper
]
=
[]
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self
.
non_driver_workers
:
List
[
ProcessWorkerWrapper
]
=
[]
if
world_size
==
1
:
self
.
workers
=
[]
self
.
worker_monitor
=
None
else
:
result_handler
=
ResultHandler
()
self
.
workers
=
[
ProcessWorkerWrapper
(
for
rank
in
range
(
1
,
world_size
):
worker
=
ProcessWorkerWrapper
(
result_handler
,
partial
(
self
.
_create_worker
,
rank
=
rank
,
local_rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
))
for
rank
in
range
(
1
,
world_size
)
]
create_worker
,
**
self
.
_get_create_worker_kwargs
(
rank
=
rank
,
local_rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)))
self
.
workers
.
append
(
worker
)
if
rank
%
tensor_parallel_size
==
0
:
self
.
tp_driver_workers
.
append
(
worker
)
else
:
self
.
non_driver_workers
.
append
(
worker
)
self
.
worker_monitor
=
WorkerMonitor
(
self
.
workers
,
result_handler
)
result_handler
.
start
()
self
.
worker_monitor
.
start
()
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref
=
weakref
.
ref
(
self
)
def
shutdown
(
signum
,
frame
):
if
executor
:
=
ref
():
executor
.
shutdown
()
signal
.
signal
(
signal
.
SIGINT
,
shutdown
)
signal
.
signal
(
signal
.
SIGTERM
,
shutdown
)
self
.
driver_worker
=
self
.
_create_worker
(
distributed_init_method
=
distributed_init_method
)
self
.
_run_workers
(
"init_device"
)
...
...
@@ -121,16 +161,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
# Start the workers first.
if
async_run_tensor_parallel_workers_only
:
# Run only non-driver workers and just return futures.
return
[
worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
non_driver_workers
]
# Start all remote workers first.
worker_outputs
=
[
worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
async_run_tensor_parallel_workers_only
:
# Just return futures
return
worker_outputs
driver_worker_method
=
getattr
(
self
.
driver_worker
,
method
)
driver_worker_output
=
driver_worker_method
(
*
args
,
**
kwargs
)
...
...
@@ -157,16 +200,45 @@ class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_model
=
make_async
(
self
.
driver_worker
.
execute_model
)
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
async
def
_driver_execute_model_async
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
return
await
self
.
driver_exec_model
(
execute_model_req
)
if
not
self
.
tp_driver_workers
:
return
await
self
.
driver_exec_model
(
execute_model_req
)
if
self
.
pp_locks
is
None
:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self
.
pp_locks
=
[
asyncio
.
Lock
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
tasks
=
[
asyncio
.
create_task
(
_run_task_with_lock
(
self
.
driver_exec_model
,
self
.
pp_locks
[
0
],
execute_model_req
))
]
for
pp_rank
,
driver_worker
in
enumerate
(
self
.
tp_driver_workers
,
start
=
1
):
tasks
.
append
(
asyncio
.
create_task
(
_run_task_with_lock
(
driver_worker
.
execute_method_async
,
self
.
pp_locks
[
pp_rank
],
"execute_model"
,
execute_model_req
)))
results
=
await
asyncio
.
gather
(
*
tasks
)
# Only the last PP stage has the final results.
return
results
[
-
1
]
async
def
_start_worker_execution_loop
(
self
):
coros
=
[
worker
.
execute_method_async
(
"start_worker_execution_loop"
)
for
worker
in
self
.
workers
for
worker
in
self
.
non_driver_
workers
]
return
await
asyncio
.
gather
(
*
coros
)
Prev
1
…
5
6
7
8
9
10
11
12
13
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment